Merge branch 'develop' into fix/prajjawal-9443

Prajjawal Agarwal 2023-08-23 04:22:28 -04:00 committed by GitHub
commit 44e2a2fc57
25 changed files with 339 additions and 116 deletions

Cargo.lock (generated)

@@ -332,18 +332,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "serde"
-version = "1.0.183"
+version = "1.0.184"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "32ac8da02677876d532745a130fc9d8e6edfa81a269b107c5b00829b91d8eb3c"
+checksum = "2c911f4b04d7385c9035407a4eff5903bf4fe270fa046fda448b69e797f4fff0"
dependencies = [
 "serde_derive",
]

[[package]]
name = "serde_derive"
-version = "1.0.183"
+version = "1.0.184"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aafe972d60b0b9bee71a91b92fee2d4fb3c9d7e8f6b179aa99f27203d99a4816"
+checksum = "c1df27f5b29406ada06609b2e2f77fb34f6dbb104a457a671cc31dbed237e09e"
dependencies = [
 "proc-macro2",
 "quote",

changelog.d/16125.misc (new file)

@@ -0,0 +1 @@
Add an admin endpoint to allow the authorizing server to signal token revocations.

changelog.d/16127.bugfix (new file)

@@ -0,0 +1 @@
User consent and 3-PID changes capability cannot be enabled when using experimental [MSC3861](https://github.com/matrix-org/matrix-spec-proposals/pull/3861) support.

changelog.d/16134.bugfix (new file)

@@ -0,0 +1 @@
User consent and 3-PID changes capability cannot be enabled when using experimental [MSC3861](https://github.com/matrix-org/matrix-spec-proposals/pull/3861) support.

changelog.d/16148.bugfix (new file)

@@ -0,0 +1 @@
Fix performance degradation when there are a lot of in-flight replication requests.

changelog.d/16150.misc (new file)

@@ -0,0 +1 @@
Clean up calling `setup_background_tasks` in unit tests.

changelog.d/16152.misc (new file)

@@ -0,0 +1 @@
Raised the poetry-core version cap to 1.7.0.

changelog.d/16157.misc (new file)

@@ -0,0 +1 @@
Fix assertion in user directory unit tests.

changelog.d/16158.misc (new file)

@@ -0,0 +1 @@
Improve presence tests.


@@ -367,7 +367,7 @@ furo = ">=2022.12.7,<2024.0.0"
# system changes.
# We are happy to raise these upper bounds upon request,
# provided we check that it's safe to do so (i.e. that CI passes).
-requires = ["poetry-core>=1.1.0,<=1.6.0", "setuptools_rust>=1.3,<=1.6.0"]
+requires = ["poetry-core>=1.1.0,<=1.7.0", "setuptools_rust>=1.3,<=1.6.0"]
build-backend = "poetry.core.masonry.api"


@@ -438,3 +438,16 @@ class MSC3861DelegatedAuth(BaseAuth):
            scope=scope,
            is_guest=(has_guest_scope and not has_user_scope),
        )
+
+    def invalidate_cached_tokens(self, keys: List[str]) -> None:
+        """
+        Invalidate the entries in the introspection token cache corresponding to the given keys.
+        """
+        for key in keys:
+            self._token_cache.invalidate(key)
+
+    def invalidate_token_cache(self) -> None:
+        """
+        Invalidate the entire introspection token cache.
+        """
+        self._token_cache.invalidate_all()


@@ -173,6 +173,13 @@ class MSC3861:
                ("enable_registration",),
            )

+        # We only need to check the user consent version, as it must be set if the user_consent section was present in the config
+        if root.consent.user_consent_version is not None:
+            raise ConfigError(
+                "User consent cannot be enabled when OAuth delegation is enabled",
+                ("user_consent",),
+            )
+
        if (
            root.oidc.oidc_enabled
            or root.saml2.saml2_enabled
@@ -216,6 +223,12 @@
                ("session_lifetime",),
            )

+        if root.registration.enable_3pid_changes:
+            raise ConfigError(
+                "enable_3pid_changes cannot be enabled when OAuth delegation is enabled",
+                ("enable_3pid_changes",),
+            )
+

@attr.s(auto_attribs=True, frozen=True, slots=True)
class MSC3866Config:


@@ -133,7 +133,16 @@ class RegistrationConfig(Config):
        self.enable_set_displayname = config.get("enable_set_displayname", True)
        self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
-        self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+
+        # The default value of enable_3pid_changes is True, unless msc3861 is enabled.
+        msc3861_enabled = (
+            (config.get("experimental_features") or {})
+            .get("msc3861", {})
+            .get("enabled", False)
+        )
+        self.enable_3pid_changes = config.get(
+            "enable_3pid_changes", not msc3861_enabled
+        )

        self.disable_msisdn_registration = config.get(
            "disable_msisdn_registration", False


@@ -14,7 +14,9 @@
"""A replication client for use by synapse workers.
"""
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple
+
+from sortedcontainers import SortedList

from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -26,6 +28,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import (
    AccountDataStream,
+    CachesStream,
    DeviceListsStream,
    PushersStream,
    PushRulesStream,
@@ -73,6 +76,7 @@ class ReplicationDataHandler:
        self._instance_name = hs.get_instance_name()
        self._typing_handler = hs.get_typing_handler()
        self._state_storage_controller = hs.get_storage_controllers().state
+        self.auth = hs.get_auth()

        self._notify_pushers = hs.config.worker.start_pushers
        self._pusher_pool = hs.get_pusherpool()
@@ -84,7 +88,9 @@
        # Map from stream and instance to list of deferreds waiting for the stream to
        # arrive at a particular position. The lists are sorted by stream position.
-        self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {}
+        self._streams_to_waiters: Dict[
+            Tuple[str, str], SortedList[Tuple[int, Deferred]]
+        ] = {}

    async def on_rdata(
        self, stream_name: str, instance_name: str, token: int, rows: list
@@ -218,6 +224,16 @@
                self._state_storage_controller.notify_event_un_partial_stated(
                    row.event_id
                )
+        # invalidate the introspection token cache
+        elif stream_name == CachesStream.NAME:
+            for row in rows:
+                if row.cache_func == "introspection_token_invalidation":
+                    if row.keys[0] is None:
+                        # invalidate the whole cache
+                        # mypy ignore - the token cache is defined on MSC3861DelegatedAuth
+                        self.auth.invalidate_token_cache()  # type: ignore[attr-defined]
+                    else:
+                        self.auth.invalidate_cached_tokens(row.keys)  # type: ignore[attr-defined]

        await self._presence_handler.process_replication_rows(
            stream_name, instance_name, token, rows
@@ -226,7 +242,9 @@
        # Notify any waiting deferreds. The list is ordered by position so we
        # just iterate through the list until we reach a position that is
        # greater than the received row position.
-        waiting_list = self._streams_to_waiters.get((stream_name, instance_name), [])
+        waiting_list = self._streams_to_waiters.get((stream_name, instance_name))
+        if not waiting_list:
+            return

        # Index of first item with a position after the current token, i.e we
        # have called all deferreds before this index. If not overwritten by
@@ -250,7 +268,7 @@
        # Drop all entries in the waiting list that were called in the above
        # loop. (This maintains the order so no need to resort)
-        waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+        del waiting_list[:index_of_first_deferred_not_called]

        for deferred in deferreds_to_callback:
            try:
@@ -310,11 +328,10 @@
        )

        waiting_list = self._streams_to_waiters.setdefault(
-            (stream_name, instance_name), []
+            (stream_name, instance_name), SortedList(key=lambda t: t[0])
        )
-        waiting_list.append((position, deferred))
-        waiting_list.sort(key=lambda t: t[0])
+        waiting_list.add((position, deferred))

        # We measure here to get in flight counts and average waiting time.
        with Measure(self._clock, "repl.wait_for_stream_position"):
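
A side note on the switch to `SortedList` above: keeping the waiters ordered on insert avoids re-sorting the whole list on every `wait_for_stream_position` call, and the satisfied prefix can be dropped with a single slice deletion. A standalone sketch of that pattern (illustrative, not Synapse code):

from sortedcontainers import SortedList

# Waiters are (position, payload) pairs kept ordered by position.
waiters = SortedList(key=lambda t: t[0])
waiters.add((7, "deferred-c"))
waiters.add((3, "deferred-a"))
waiters.add((5, "deferred-b"))

token = 5
# Everything at or before the received token is ready to be called back.
index = waiters.bisect_key_right(token)
ready = list(waiters[:index])
del waiters[:index]  # drop the satisfied prefix; remaining order is preserved

assert ready == [(3, "deferred-a"), (5, "deferred-b")]
assert list(waiters) == [(7, "deferred-c")]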


@@ -47,6 +47,7 @@ from synapse.rest.admin.federation import (
    ListDestinationsRestServlet,
)
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
+from synapse.rest.admin.oidc import OIDCTokenRevocationRestServlet
from synapse.rest.admin.registration_tokens import (
    ListRegistrationTokensRestServlet,
    NewRegistrationTokenRestServlet,
@@ -297,6 +298,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    BackgroundUpdateRestServlet(hs).register(http_server)
    BackgroundUpdateStartJobRestServlet(hs).register(http_server)
    ExperimentalFeaturesRestServlet(hs).register(http_server)
+    if hs.config.experimental.msc3861.enabled:
+        OIDCTokenRevocationRestServlet(hs).register(http_server)


def register_servlets_for_client_rest_resource(


@@ -0,0 +1,55 @@
# Copyright 2023 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Tuple

from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin

if TYPE_CHECKING:
    from synapse.server import HomeServer


class OIDCTokenRevocationRestServlet(RestServlet):
    """
    Delete a given token introspection response - identified by the `jti` field -
    from the introspection token cache when a token is revoked at the authorizing
    server.
    """

    PATTERNS = admin_patterns("/OIDC_token_revocation/(?P<token_id>[^/]*)")

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        auth = hs.get_auth()

        # If this endpoint is loaded then we must have enabled delegated auth.
        from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth

        assert isinstance(auth, MSC3861DelegatedAuth)

        self.auth = auth
        self.store = hs.get_datastores().main

    async def on_DELETE(
        self, request: SynapseRequest, token_id: str
    ) -> Tuple[HTTPStatus, Dict]:
        await assert_requester_is_admin(self.auth, request)

        self.auth._token_cache.invalidate(token_id)

        # make sure we invalidate the cache on any workers
        await self.store.stream_introspection_token_invalidation((token_id,))

        return HTTPStatus.OK, {}
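
For illustration only (not part of this commit): once the servlet is registered, the authorizing server can revoke a cached introspection result by its `jti` with an admin-authenticated DELETE. The homeserver URL, admin token, and `jti` below are placeholders.

import json
import urllib.request

HOMESERVER = "https://synapse.example.com"  # placeholder
ADMIN_TOKEN = "syt_admin_access_token"      # placeholder admin access token
JTI = "open_sesame"                         # the `jti` of the revoked token

req = urllib.request.Request(
    f"{HOMESERVER}/_synapse/admin/v1/OIDC_token_revocation/{JTI}",
    method="DELETE",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
with urllib.request.urlopen(req) as resp:
    # The servlet responds 200 with an empty JSON object once the entry is gone.
    print(resp.status, json.loads(resp.read()))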


@@ -584,6 +584,19 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
        else:
            return 0

+    async def stream_introspection_token_invalidation(
+        self, key: Tuple[Optional[str]]
+    ) -> None:
+        """
+        Stream an invalidation request for the introspection token cache to workers.
+
+        Args:
+            key: token_id of the introspection token to remove from the cache
+        """
+        await self.send_invalidation_to_replication(
+            "introspection_token_invalidation", key
+        )
+
    @wrap_as_background_process("clean_up_old_cache_invalidations")
    async def _clean_up_cache_invalidation_wrapper(self) -> None:
        """


@@ -33,6 +33,7 @@ from typing_extensions import Literal

from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
+from synapse.config.homeserver import HomeServerConfig
from synapse.logging.opentracing import (
    get_active_span_text_map,
    set_tag,
@@ -1663,6 +1664,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
        self.device_id_exists_cache: LruCache[
            Tuple[str, str], Literal[True]
        ] = LruCache(cache_name="device_id_exists", max_size=10000)
+        self.config: HomeServerConfig = hs.config

    async def store_device(
        self,
@@ -1784,6 +1786,13 @@
        for device_id in device_ids:
            self.device_id_exists_cache.invalidate((user_id, device_id))

+        # TODO: don't nuke the entire cache once there is a way to associate
+        # device_id -> introspection_token
+        if self.config.experimental.msc3861.enabled:
+            # mypy ignore - the token cache is defined on MSC3861DelegatedAuth
+            self.auth._token_cache.invalidate_all()  # type: ignore[attr-defined]
+            await self.stream_introspection_token_invalidation((None,))
+
    async def update_device(
        self, user_id: str, device_id: str, new_display_name: Optional[str] = None
    ) -> None:


@@ -140,6 +140,20 @@ class ExpiringCache(Generic[KT, VT]):

        return value.value

+    def invalidate(self, key: KT) -> None:
+        """
+        Remove the given key from the cache.
+        """
+
+        value = self._cache.pop(key, None)
+        if value:
+            if self.iterable:
+                self.metrics.inc_evictions(
+                    EvictionReason.invalidation, len(value.value)
+                )
+            else:
+                self.metrics.inc_evictions(EvictionReason.invalidation)
+
    def __contains__(self, key: KT) -> bool:
        return key in self._cache

@@ -193,6 +207,14 @@
            len(self),
        )

+    def invalidate_all(self) -> None:
+        """
+        Remove all items from the cache.
+        """
+        keys = set(self._cache.keys())
+        for key in keys:
+            self._cache.pop(key)
+
    def __len__(self) -> int:
        if self.iterable:
            return sum(len(entry.value) for entry in self._cache.values())
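
To make the intended semantics of the two helpers explicit (a toy behavioural sketch, deliberately not using Synapse's real classes): `invalidate()` drops one entry and records an eviction, while `invalidate_all()` empties the backing map, copying the keys first so the dict is not mutated while it is being iterated.

class ToyCache:
    """Toy stand-in for the ExpiringCache behaviour shown above."""

    def __init__(self) -> None:
        self._cache: dict = {}
        self.evictions = 0  # stands in for self.metrics.inc_evictions(...)

    def invalidate(self, key) -> None:
        if self._cache.pop(key, None) is not None:
            self.evictions += 1

    def invalidate_all(self) -> None:
        for key in set(self._cache.keys()):
            self._cache.pop(key)

cache = ToyCache()
cache._cache.update({"a": 1, "b": 2, "c": 3})
cache.invalidate("a")    # single entry gone, eviction counted
cache.invalidate_all()   # everything else gone
assert cache.evictions == 1 and len(cache._cache) == 0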


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import os
from unittest.mock import Mock

from synapse.config import ConfigError
@@ -167,6 +168,21 @@ class MSC3861OAuthDelegation(TestCase):
        with self.assertRaises(ConfigError):
            self.parse_config()

+    def test_user_consent_cannot_be_enabled(self) -> None:
+        tmpdir = self.mktemp()
+        os.mkdir(tmpdir)
+        self.config_dict["user_consent"] = {
+            "require_at_registration": True,
+            "version": "1",
+            "template_dir": tmpdir,
+            "server_notice_content": {
+                "msgtype": "m.text",
+                "body": "foo",
+            },
+        }
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
    def test_password_config_cannot_be_enabled(self) -> None:
        self.config_dict["password_config"] = {"enabled": True}
        with self.assertRaises(ConfigError):
@@ -255,3 +271,8 @@
        self.config_dict["session_lifetime"] = "24h"
        with self.assertRaises(ConfigError):
            self.parse_config()
+
+    def test_enable_3pid_changes_cannot_be_enabled(self) -> None:
+        self.config_dict["enable_3pid_changes"] = True
+
+        with self.assertRaises(ConfigError):
+            self.parse_config()


@@ -14,7 +14,7 @@
from http import HTTPStatus
from typing import Any, Dict, Union
-from unittest.mock import ANY, Mock
+from unittest.mock import ANY, AsyncMock, Mock
from urllib.parse import parse_qs

from signedjson.key import (
@@ -588,6 +588,38 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
        )
        self.assertEqual(self.http_client.request.call_count, 2)

+    def test_revocation_endpoint(self) -> None:
+        # mock introspection response and then admin verification response
+        self.http_client.request = AsyncMock(
+            side_effect=[
+                FakeResponse.json(
+                    code=200, payload={"active": True, "jti": "open_sesame"}
+                ),
+                FakeResponse.json(
+                    code=200,
+                    payload={
+                        "active": True,
+                        "sub": SUBJECT,
+                        "scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
+                        "username": USERNAME,
+                    },
+                ),
+            ]
+        )
+        # cache a token to delete
+        introspection_token = self.get_success(
+            self.auth._introspect_token("open_sesame")  # type: ignore[attr-defined]
+        )
+        self.assertEqual(self.auth._token_cache.get("open_sesame"), introspection_token)  # type: ignore[attr-defined]
+
+        # delete the revoked token
+        introspection_token_id = "open_sesame"
+        url = f"/_synapse/admin/v1/OIDC_token_revocation/{introspection_token_id}"
+        channel = self.make_request("DELETE", url, access_token="mockAccessToken")
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(self.auth._token_cache.get("open_sesame"), None)  # type: ignore[attr-defined]
+
    def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
        # We only generate a master key to simplify the test.
        master_signing_key = generate_signing_key(device_id)


@@ -514,6 +514,9 @@ class PresenceTimeoutTestCase(unittest.TestCase):
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
+    user_id = "@test:server"
+    user_id_obj = UserID.from_string(user_id)
+
    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        self.presence_handler = hs.get_presence_handler()
        self.clock = hs.get_clock()
@@ -523,12 +526,11 @@
        we time out their syncing users presence.
        """
        process_id = "1"
-        user_id = "@test:server"

        # Notify handler that a user is now syncing.
        self.get_success(
            self.presence_handler.update_external_syncs_row(
-                process_id, user_id, True, self.clock.time_msec()
+                process_id, self.user_id, True, self.clock.time_msec()
            )
        )
@@ -536,48 +538,37 @@
        # stopped syncing that their presence state doesn't get timed out.
        self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2)

-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        self.assertEqual(state.state, PresenceState.ONLINE)

        # Check that if the external process timeout fires, then the syncing
        # user gets timed out
        self.reactor.advance(EXTERNAL_PROCESS_EXPIRY)

-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        self.assertEqual(state.state, PresenceState.OFFLINE)

    def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
        """Test that if a user doesn't update the records for a while
        users presence goes `OFFLINE` because of timeout and `status_msg` remains.
        """
-        user_id = "@test:server"
        status_msg = "I'm here!"

        # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)

        # Check that if we wait a while without telling the handler the user has
        # stopped syncing that their presence state doesn't get timed out.
        self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2)

-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        self.assertEqual(state.state, PresenceState.ONLINE)
        self.assertEqual(state.status_msg, status_msg)

        # Check that if the timeout fires, then the syncing user gets timed out
        self.reactor.advance(SYNC_ONLINE_TIMEOUT)

-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        # status_msg should remain even after going offline
        self.assertEqual(state.state, PresenceState.OFFLINE)
        self.assertEqual(state.status_msg, status_msg)
@@ -586,24 +577,19 @@
        """Test that if a user change presence manually to `OFFLINE`
        and no status is set, that `status_msg` is `None`.
        """
-        user_id = "@test:server"
        status_msg = "I'm here!"

        # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)

        # Mark user as offline
        self.get_success(
            self.presence_handler.set_state(
-                UserID.from_string(user_id), {"presence": PresenceState.OFFLINE}
+                self.user_id_obj, {"presence": PresenceState.OFFLINE}
            )
        )

-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        self.assertEqual(state.state, PresenceState.OFFLINE)
        self.assertEqual(state.status_msg, None)
@@ -611,41 +597,31 @@
        """Test that if a user change presence manually to `OFFLINE`
        and a status is set, that `status_msg` appears.
        """
-        user_id = "@test:server"
        status_msg = "I'm here!"

        # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)

        # Mark user as offline
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.OFFLINE, "And now here."
-        )
+        self._set_presencestate_with_status_msg(PresenceState.OFFLINE, "And now here.")

    def test_user_reset_online_with_no_status(self) -> None:
        """Test that if a user set again the presence manually
        and no status is set, that `status_msg` is `None`.
        """
-        user_id = "@test:server"
        status_msg = "I'm here!"

        # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)

        # Mark user as online again
        self.get_success(
            self.presence_handler.set_state(
-                UserID.from_string(user_id), {"presence": PresenceState.ONLINE}
+                self.user_id_obj, {"presence": PresenceState.ONLINE}
            )
        )

-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        # status_msg should remain even after going offline
        self.assertEqual(state.state, PresenceState.ONLINE)
        self.assertEqual(state.status_msg, None)
@@ -654,33 +630,27 @@
        """Test that if a user set again the presence manually
        and status is `None`, that `status_msg` is `None`.
        """
-        user_id = "@test:server"
        status_msg = "I'm here!"

        # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)

        # Mark user as online and `status_msg = None`
-        self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, None)

    def test_set_presence_from_syncing_not_set(self) -> None:
        """Test that presence is not set by syncing if affect_presence is false"""
-        user_id = "@test:server"
        status_msg = "I'm here!"

-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.UNAVAILABLE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)

        self.get_success(
-            self.presence_handler.user_syncing(user_id, False, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(
+                self.user_id, False, PresenceState.ONLINE
+            )
        )

-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        # we should still be unavailable
        self.assertEqual(state.state, PresenceState.UNAVAILABLE)
        # and status message should still be the same
@@ -688,50 +658,34 @@
    def test_set_presence_from_syncing_is_set(self) -> None:
        """Test that presence is set by syncing if affect_presence is true"""
-        user_id = "@test:server"
        status_msg = "I'm here!"

-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.UNAVAILABLE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)

        self.get_success(
-            self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
        )

-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        # we should now be online
        self.assertEqual(state.state, PresenceState.ONLINE)

    def test_set_presence_from_syncing_keeps_status(self) -> None:
        """Test that presence set by syncing retains status message"""
-        user_id = "@test:server"
        status_msg = "I'm here!"

-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.UNAVAILABLE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)

        self.get_success(
-            self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
        )

-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        # our status message should be the same as it was before
        self.assertEqual(state.status_msg, status_msg)

    @parameterized.expand([(False,), (True,)])
-    @unittest.override_config(
-        {
-            "experimental_features": {
-                "msc3026_enabled": True,
-            },
-        }
-    )
+    @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
    def test_set_presence_from_syncing_keeps_busy(
        self, test_with_workers: bool
    ) -> None:
@@ -741,7 +695,6 @@
            test_with_workers: If True, check the presence state of the user by calling
                /sync against a worker, rather than the main process.
        """
-        user_id = "@test:server"
        status_msg = "I'm busy!"

        # By default, we call /sync against the main process.
@@ -755,44 +708,39 @@
        )

        # Set presence to BUSY
-        self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg)
+        self._set_presencestate_with_status_msg(PresenceState.BUSY, status_msg)

        # Perform a sync with a presence state other than busy. This should NOT change
        # our presence status; we only change from busy if we explicitly set it via
        # /presence/*.
        self.get_success(
            worker_to_sync_against.get_presence_handler().user_syncing(
-                user_id, True, PresenceState.ONLINE
+                self.user_id, True, PresenceState.ONLINE
            )
        )

        # Check against the main process that the user's presence did not change.
-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        # we should still be busy
        self.assertEqual(state.state, PresenceState.BUSY)

    def _set_presencestate_with_status_msg(
-        self, user_id: str, state: str, status_msg: Optional[str]
+        self, state: str, status_msg: Optional[str]
    ) -> None:
        """Set a PresenceState and status_msg and check the result.

        Args:
-            user_id: User for that the status is to be set.
            state: The new PresenceState.
            status_msg: Status message that is to be set.
        """
        self.get_success(
            self.presence_handler.set_state(
-                UserID.from_string(user_id),
+                self.user_id_obj,
                {"presence": state, "status_msg": status_msg},
            )
        )

-        new_state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        new_state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
        self.assertEqual(new_state.state, state)
        self.assertEqual(new_state.status_msg, status_msg)
@@ -952,9 +900,6 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
        self.assertEqual(upto_token, now_token)
        self.assertFalse(limited)

-        expected_rows = [
-            (2, ("dest3", "@user3:test")),
-        ]
        self.assertCountEqual(rows, [])

        prev_token = self.queue.get_current_token(self.instance_name)


@@ -446,6 +446,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
        self.assertIsNone(profile)

    def test_handle_user_deactivated_support_user(self) -> None:
+        """Ensure a support user doesn't get added to the user directory after deactivation."""
        s_user_id = "@support:test"
        self.get_success(
            self.store.register_user(
@@ -453,14 +454,16 @@
            )
        )

-        mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
-        with patch.object(
-            self.store, "remove_from_user_dir", mock_remove_from_user_dir
-        ):
-            self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
-        # BUG: the correct spelling is assert_not_called, but that makes the test fail
-        # and it's not clear that this is actually the behaviour we want.
-        mock_remove_from_user_dir.not_called()
+        # The profile should not be in the directory.
+        profile = self.get_success(self.store._get_user_in_directory(s_user_id))
+        self.assertIsNone(profile)
+
+        # Remove the user from the directory.
+        self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
+
+        # The profile should still not be in the user directory.
+        profile = self.get_success(self.store._get_user_in_directory(s_user_id))
+        self.assertIsNone(profile)

    def test_handle_user_deactivated_regular_user(self) -> None:
        r_user_id = "@regular:test"


@@ -0,0 +1,62 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict

import synapse.rest.admin._base

from tests.replication._base import BaseMultiWorkerStreamTestCase


class IntrospectionTokenCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
    servlets = [synapse.rest.admin.register_servlets]

    def default_config(self) -> Dict[str, Any]:
        config = super().default_config()
        config["disable_registration"] = True
        config["experimental_features"] = {
            "msc3861": {
                "enabled": True,
                "issuer": "some_dude",
                "client_id": "ID",
                "client_auth_method": "client_secret_post",
                "client_secret": "secret",
            }
        }
        return config

    def test_stream_introspection_token_invalidation(self) -> None:
        worker_hs = self.make_worker_hs("synapse.app.generic_worker")
        auth = worker_hs.get_auth()
        store = self.hs.get_datastores().main

        # add a token to the cache on the worker
        auth._token_cache["open_sesame"] = "intro_token"  # type: ignore[attr-defined]

        # stream the invalidation from the master
        self.get_success(
            store.stream_introspection_token_invalidation(("open_sesame",))
        )

        # check that the cache on the worker was invalidated
        self.assertEqual(auth._token_cache.get("open_sesame"), None)  # type: ignore[attr-defined]

        # test invalidating whole cache
        for i in range(0, 5):
            auth._token_cache[f"open_sesame_{i}"] = f"intro_token_{i}"  # type: ignore[attr-defined]
        self.assertEqual(len(auth._token_cache), 5)  # type: ignore[attr-defined]

        self.get_success(store.stream_introspection_token_invalidation((None,)))

        self.assertEqual(len(auth._token_cache), 0)  # type: ignore[attr-defined]


@@ -1000,8 +1000,6 @@ def setup_test_homeserver(
        hs.tls_server_context_factory = Mock()

    hs.setup()
-    if homeserver_to_use == TestHomeServer:
-        hs.setup_background_tasks()

    if isinstance(db_engine, PostgresEngine):
        database_pool = hs.get_datastores().databases[0]