Add missing type hints to tests.config. (#14681)
This commit is contained in:
parent
864c3f85b0
commit
3aeca2588b
|
@ -0,0 +1 @@
|
||||||
|
Add missing type hints.
|
4
mypy.ini
4
mypy.ini
|
@ -36,8 +36,6 @@ exclude = (?x)
|
||||||
|tests/api/test_ratelimiting.py
|
|tests/api/test_ratelimiting.py
|
||||||
|tests/app/test_openid_listener.py
|
|tests/app/test_openid_listener.py
|
||||||
|tests/appservice/test_scheduler.py
|
|tests/appservice/test_scheduler.py
|
||||||
|tests/config/test_cache.py
|
|
||||||
|tests/config/test_tls.py
|
|
||||||
|tests/crypto/test_keyring.py
|
|tests/crypto/test_keyring.py
|
||||||
|tests/events/test_presence_router.py
|
|tests/events/test_presence_router.py
|
||||||
|tests/events/test_utils.py
|
|tests/events/test_utils.py
|
||||||
|
@ -89,7 +87,7 @@ disallow_untyped_defs = False
|
||||||
[mypy-tests.*]
|
[mypy-tests.*]
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
[mypy-tests.config.test_api]
|
[mypy-tests.config.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.federation.transport.test_client]
|
[mypy-tests.federation.transport.test_client]
|
||||||
|
|
|
@ -16,7 +16,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, Mapping, Optional
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ def add_resizable_cache(
|
||||||
|
|
||||||
class CacheConfig(Config):
|
class CacheConfig(Config):
|
||||||
section = "caches"
|
section = "caches"
|
||||||
_environ = os.environ
|
_environ: Mapping[str, str] = os.environ
|
||||||
|
|
||||||
event_cache_size: int
|
event_cache_size: int
|
||||||
cache_factors: Dict[str, float]
|
cache_factors: Dict[str, float]
|
||||||
|
|
|
@ -788,26 +788,21 @@ class LruCache(Generic[KT, VT]):
|
||||||
def __contains__(self, key: KT) -> bool:
|
def __contains__(self, key: KT) -> bool:
|
||||||
return self.contains(key)
|
return self.contains(key)
|
||||||
|
|
||||||
def set_cache_factor(self, factor: float) -> bool:
|
def set_cache_factor(self, factor: float) -> None:
|
||||||
"""
|
"""
|
||||||
Set the cache factor for this individual cache.
|
Set the cache factor for this individual cache.
|
||||||
|
|
||||||
This will trigger a resize if it changes, which may require evicting
|
This will trigger a resize if it changes, which may require evicting
|
||||||
items from the cache.
|
items from the cache.
|
||||||
|
|
||||||
Returns:
|
|
||||||
Whether the cache changed size or not.
|
|
||||||
"""
|
"""
|
||||||
if not self.apply_cache_factor_from_config:
|
if not self.apply_cache_factor_from_config:
|
||||||
return False
|
return
|
||||||
|
|
||||||
new_size = int(self._original_max_size * factor)
|
new_size = int(self._original_max_size * factor)
|
||||||
if new_size != self.max_size:
|
if new_size != self.max_size:
|
||||||
self.max_size = new_size
|
self.max_size = new_size
|
||||||
if self._on_resize:
|
if self._on_resize:
|
||||||
self._on_resize()
|
self._on_resize()
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
# We're about to be deleted, so we make sure to clear up all the nodes
|
# We're about to be deleted, so we make sure to clear up all the nodes
|
||||||
|
|
|
@ -17,15 +17,15 @@ from tests.config.utils import ConfigFileTestCase
|
||||||
|
|
||||||
|
|
||||||
class ConfigMainFileTestCase(ConfigFileTestCase):
|
class ConfigMainFileTestCase(ConfigFileTestCase):
|
||||||
def test_executes_without_an_action(self):
|
def test_executes_without_an_action(self) -> None:
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
main(["", "-c", self.config_file])
|
main(["", "-c", self.config_file])
|
||||||
|
|
||||||
def test_read__error_if_key_not_found(self):
|
def test_read__error_if_key_not_found(self) -> None:
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
with self.assertRaises(SystemExit):
|
with self.assertRaises(SystemExit):
|
||||||
main(["", "read", "foo.bar.hello", "-c", self.config_file])
|
main(["", "read", "foo.bar.hello", "-c", self.config_file])
|
||||||
|
|
||||||
def test_read__passes_if_key_found(self):
|
def test_read__passes_if_key_found(self) -> None:
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
main(["", "read", "server.server_name", "-c", self.config_file])
|
main(["", "read", "server.server_name", "-c", self.config_file])
|
||||||
|
|
|
@ -22,7 +22,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase):
|
||||||
# Tests that the default values in the config are correctly loaded. Note that the default
|
# Tests that the default values in the config are correctly loaded. Note that the default
|
||||||
# values are loaded when the corresponding config options are commented out, which is why there isn't
|
# values are loaded when the corresponding config options are commented out, which is why there isn't
|
||||||
# a config specified here.
|
# a config specified here.
|
||||||
def test_default_configuration(self):
|
def test_default_configuration(self) -> None:
|
||||||
background_updater = BackgroundUpdater(
|
background_updater = BackgroundUpdater(
|
||||||
self.hs, self.hs.get_datastores().main.db_pool
|
self.hs, self.hs.get_datastores().main.db_pool
|
||||||
)
|
)
|
||||||
|
@ -46,7 +46,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def test_custom_configuration(self):
|
def test_custom_configuration(self) -> None:
|
||||||
background_updater = BackgroundUpdater(
|
background_updater = BackgroundUpdater(
|
||||||
self.hs, self.hs.get_datastores().main.db_pool
|
self.hs, self.hs.get_datastores().main.db_pool
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,13 +24,13 @@ from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class BaseConfigTestCase(unittest.TestCase):
|
class BaseConfigTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
# The root object needs a server property with a public_baseurl.
|
# The root object needs a server property with a public_baseurl.
|
||||||
root = Mock()
|
root = Mock()
|
||||||
root.server.public_baseurl = "http://test"
|
root.server.public_baseurl = "http://test"
|
||||||
self.config = Config(root)
|
self.config = Config(root)
|
||||||
|
|
||||||
def test_loading_missing_templates(self):
|
def test_loading_missing_templates(self) -> None:
|
||||||
# Use a temporary directory that exists on the system, but that isn't likely to
|
# Use a temporary directory that exists on the system, but that isn't likely to
|
||||||
# contain template files
|
# contain template files
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -50,7 +50,7 @@ class BaseConfigTestCase(unittest.TestCase):
|
||||||
"Template file did not contain our test string",
|
"Template file did not contain our test string",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_loading_custom_templates(self):
|
def test_loading_custom_templates(self) -> None:
|
||||||
# Use a temporary directory that exists on the system
|
# Use a temporary directory that exists on the system
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
# Create a temporary bogus template file
|
# Create a temporary bogus template file
|
||||||
|
@ -79,7 +79,7 @@ class BaseConfigTestCase(unittest.TestCase):
|
||||||
"Template file did not contain our test string",
|
"Template file did not contain our test string",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_multiple_custom_template_directories(self):
|
def test_multiple_custom_template_directories(self) -> None:
|
||||||
"""Tests that directories are searched in the right order if multiple custom
|
"""Tests that directories are searched in the right order if multiple custom
|
||||||
template directories are provided.
|
template directories are provided.
|
||||||
"""
|
"""
|
||||||
|
@ -137,7 +137,7 @@ class BaseConfigTestCase(unittest.TestCase):
|
||||||
for td in tempdirs:
|
for td in tempdirs:
|
||||||
td.cleanup()
|
td.cleanup()
|
||||||
|
|
||||||
def test_loading_template_from_nonexistent_custom_directory(self):
|
def test_loading_template_from_nonexistent_custom_directory(self) -> None:
|
||||||
with self.assertRaises(ConfigError):
|
with self.assertRaises(ConfigError):
|
||||||
self.config.read_templates(
|
self.config.read_templates(
|
||||||
["some_filename.html"], ("a_nonexistent_directory",)
|
["some_filename.html"], ("a_nonexistent_directory",)
|
||||||
|
|
|
@ -13,26 +13,27 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.config.cache import CacheConfig, add_resizable_cache
|
from synapse.config.cache import CacheConfig, add_resizable_cache
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
from tests.unittest import TestCase
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
class CacheConfigTests(TestCase):
|
class CacheConfigTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
# Reset caches before each test since there's global state involved.
|
# Reset caches before each test since there's global state involved.
|
||||||
self.config = CacheConfig()
|
self.config = CacheConfig()
|
||||||
self.config.reset()
|
self.config.reset()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self) -> None:
|
||||||
# Also reset the caches after each test to leave state pristine.
|
# Also reset the caches after each test to leave state pristine.
|
||||||
self.config.reset()
|
self.config.reset()
|
||||||
|
|
||||||
def test_individual_caches_from_environ(self):
|
def test_individual_caches_from_environ(self) -> None:
|
||||||
"""
|
"""
|
||||||
Individual cache factors will be loaded from the environment.
|
Individual cache factors will be loaded from the environment.
|
||||||
"""
|
"""
|
||||||
config = {}
|
config: JsonDict = {}
|
||||||
self.config._environ = {
|
self.config._environ = {
|
||||||
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
|
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
|
||||||
"SYNAPSE_NOT_CACHE": "BLAH",
|
"SYNAPSE_NOT_CACHE": "BLAH",
|
||||||
|
@ -42,15 +43,15 @@ class CacheConfigTests(TestCase):
|
||||||
|
|
||||||
self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
|
self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
|
||||||
|
|
||||||
def test_config_overrides_environ(self):
|
def test_config_overrides_environ(self) -> None:
|
||||||
"""
|
"""
|
||||||
Individual cache factors defined in the environment will take precedence
|
Individual cache factors defined in the environment will take precedence
|
||||||
over those in the config.
|
over those in the config.
|
||||||
"""
|
"""
|
||||||
config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
|
config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
|
||||||
self.config._environ = {
|
self.config._environ = {
|
||||||
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
|
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
|
||||||
"SYNAPSE_CACHE_FACTOR_FOO": 1,
|
"SYNAPSE_CACHE_FACTOR_FOO": "1",
|
||||||
}
|
}
|
||||||
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
self.config.resize_all_caches()
|
self.config.resize_all_caches()
|
||||||
|
@ -60,104 +61,104 @@ class CacheConfigTests(TestCase):
|
||||||
{"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
|
{"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_individual_instantiated_before_config_load(self):
|
def test_individual_instantiated_before_config_load(self) -> None:
|
||||||
"""
|
"""
|
||||||
If a cache is instantiated before the config is read, it will be given
|
If a cache is instantiated before the config is read, it will be given
|
||||||
the default cache size in the interim, and then resized once the config
|
the default cache size in the interim, and then resized once the config
|
||||||
is loaded.
|
is loaded.
|
||||||
"""
|
"""
|
||||||
cache = LruCache(100)
|
cache: LruCache = LruCache(100)
|
||||||
|
|
||||||
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
|
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
|
||||||
self.assertEqual(cache.max_size, 50)
|
self.assertEqual(cache.max_size, 50)
|
||||||
|
|
||||||
config = {"caches": {"per_cache_factors": {"foo": 3}}}
|
config: JsonDict = {"caches": {"per_cache_factors": {"foo": 3}}}
|
||||||
self.config.read_config(config)
|
self.config.read_config(config)
|
||||||
self.config.resize_all_caches()
|
self.config.resize_all_caches()
|
||||||
|
|
||||||
self.assertEqual(cache.max_size, 300)
|
self.assertEqual(cache.max_size, 300)
|
||||||
|
|
||||||
def test_individual_instantiated_after_config_load(self):
|
def test_individual_instantiated_after_config_load(self) -> None:
|
||||||
"""
|
"""
|
||||||
If a cache is instantiated after the config is read, it will be
|
If a cache is instantiated after the config is read, it will be
|
||||||
immediately resized to the correct size given the per_cache_factor if
|
immediately resized to the correct size given the per_cache_factor if
|
||||||
there is one.
|
there is one.
|
||||||
"""
|
"""
|
||||||
config = {"caches": {"per_cache_factors": {"foo": 2}}}
|
config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2}}}
|
||||||
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
self.config.resize_all_caches()
|
self.config.resize_all_caches()
|
||||||
|
|
||||||
cache = LruCache(100)
|
cache: LruCache = LruCache(100)
|
||||||
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
|
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
|
||||||
self.assertEqual(cache.max_size, 200)
|
self.assertEqual(cache.max_size, 200)
|
||||||
|
|
||||||
def test_global_instantiated_before_config_load(self):
|
def test_global_instantiated_before_config_load(self) -> None:
|
||||||
"""
|
"""
|
||||||
If a cache is instantiated before the config is read, it will be given
|
If a cache is instantiated before the config is read, it will be given
|
||||||
the default cache size in the interim, and then resized to the new
|
the default cache size in the interim, and then resized to the new
|
||||||
default cache size once the config is loaded.
|
default cache size once the config is loaded.
|
||||||
"""
|
"""
|
||||||
cache = LruCache(100)
|
cache: LruCache = LruCache(100)
|
||||||
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
|
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
|
||||||
self.assertEqual(cache.max_size, 50)
|
self.assertEqual(cache.max_size, 50)
|
||||||
|
|
||||||
config = {"caches": {"global_factor": 4}}
|
config: JsonDict = {"caches": {"global_factor": 4}}
|
||||||
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
self.config.resize_all_caches()
|
self.config.resize_all_caches()
|
||||||
|
|
||||||
self.assertEqual(cache.max_size, 400)
|
self.assertEqual(cache.max_size, 400)
|
||||||
|
|
||||||
def test_global_instantiated_after_config_load(self):
|
def test_global_instantiated_after_config_load(self) -> None:
|
||||||
"""
|
"""
|
||||||
If a cache is instantiated after the config is read, it will be
|
If a cache is instantiated after the config is read, it will be
|
||||||
immediately resized to the correct size given the global factor if there
|
immediately resized to the correct size given the global factor if there
|
||||||
is no per-cache factor.
|
is no per-cache factor.
|
||||||
"""
|
"""
|
||||||
config = {"caches": {"global_factor": 1.5}}
|
config: JsonDict = {"caches": {"global_factor": 1.5}}
|
||||||
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
self.config.resize_all_caches()
|
self.config.resize_all_caches()
|
||||||
|
|
||||||
cache = LruCache(100)
|
cache: LruCache = LruCache(100)
|
||||||
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
|
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
|
||||||
self.assertEqual(cache.max_size, 150)
|
self.assertEqual(cache.max_size, 150)
|
||||||
|
|
||||||
def test_cache_with_asterisk_in_name(self):
|
def test_cache_with_asterisk_in_name(self) -> None:
|
||||||
"""Some caches have asterisks in their name, test that they are set correctly."""
|
"""Some caches have asterisks in their name, test that they are set correctly."""
|
||||||
|
|
||||||
config = {
|
config: JsonDict = {
|
||||||
"caches": {
|
"caches": {
|
||||||
"per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
|
"per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.config._environ = {
|
self.config._environ = {
|
||||||
"SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
|
"SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
|
||||||
"SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
|
"SYNAPSE_CACHE_FACTOR_CACHE_B": "3",
|
||||||
}
|
}
|
||||||
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
self.config.resize_all_caches()
|
self.config.resize_all_caches()
|
||||||
|
|
||||||
cache_a = LruCache(100)
|
cache_a: LruCache = LruCache(100)
|
||||||
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
|
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
|
||||||
self.assertEqual(cache_a.max_size, 200)
|
self.assertEqual(cache_a.max_size, 200)
|
||||||
|
|
||||||
cache_b = LruCache(100)
|
cache_b: LruCache = LruCache(100)
|
||||||
add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
|
add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
|
||||||
self.assertEqual(cache_b.max_size, 300)
|
self.assertEqual(cache_b.max_size, 300)
|
||||||
|
|
||||||
cache_c = LruCache(100)
|
cache_c: LruCache = LruCache(100)
|
||||||
add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
|
add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
|
||||||
self.assertEqual(cache_c.max_size, 200)
|
self.assertEqual(cache_c.max_size, 200)
|
||||||
|
|
||||||
def test_apply_cache_factor_from_config(self):
|
def test_apply_cache_factor_from_config(self) -> None:
|
||||||
"""Caches can disable applying cache factor updates, mainly used by
|
"""Caches can disable applying cache factor updates, mainly used by
|
||||||
event cache size.
|
event cache size.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config = {"caches": {"event_cache_size": "10k"}}
|
config: JsonDict = {"caches": {"event_cache_size": "10k"}}
|
||||||
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
self.config.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
self.config.resize_all_caches()
|
self.config.resize_all_caches()
|
||||||
|
|
||||||
cache = LruCache(
|
cache: LruCache = LruCache(
|
||||||
max_size=self.config.event_cache_size,
|
max_size=self.config.event_cache_size,
|
||||||
apply_cache_factor_from_config=False,
|
apply_cache_factor_from_config=False,
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,7 +20,7 @@ from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class DatabaseConfigTestCase(unittest.TestCase):
|
class DatabaseConfigTestCase(unittest.TestCase):
|
||||||
def test_database_configured_correctly(self):
|
def test_database_configured_correctly(self) -> None:
|
||||||
conf = yaml.safe_load(
|
conf = yaml.safe_load(
|
||||||
DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
|
DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,14 +25,14 @@ from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class ConfigGenerationTestCase(unittest.TestCase):
|
class ConfigGenerationTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
self.dir = tempfile.mkdtemp()
|
self.dir = tempfile.mkdtemp()
|
||||||
self.file = os.path.join(self.dir, "homeserver.yaml")
|
self.file = os.path.join(self.dir, "homeserver.yaml")
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self) -> None:
|
||||||
shutil.rmtree(self.dir)
|
shutil.rmtree(self.dir)
|
||||||
|
|
||||||
def test_generate_config_generates_files(self):
|
def test_generate_config_generates_files(self) -> None:
|
||||||
with redirect_stdout(StringIO()):
|
with redirect_stdout(StringIO()):
|
||||||
HomeServerConfig.load_or_generate_config(
|
HomeServerConfig.load_or_generate_config(
|
||||||
"",
|
"",
|
||||||
|
@ -56,7 +56,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
||||||
os.path.join(os.getcwd(), "homeserver.log"),
|
os.path.join(os.getcwd(), "homeserver.log"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def assert_log_filename_is(self, log_config_file, expected):
|
def assert_log_filename_is(self, log_config_file: str, expected: str) -> None:
|
||||||
with open(log_config_file) as f:
|
with open(log_config_file) as f:
|
||||||
config = f.read()
|
config = f.read()
|
||||||
# find the 'filename' line
|
# find the 'filename' line
|
||||||
|
|
|
@ -21,14 +21,14 @@ from tests.config.utils import ConfigFileTestCase
|
||||||
|
|
||||||
|
|
||||||
class ConfigLoadingFileTestCase(ConfigFileTestCase):
|
class ConfigLoadingFileTestCase(ConfigFileTestCase):
|
||||||
def test_load_fails_if_server_name_missing(self):
|
def test_load_fails_if_server_name_missing(self) -> None:
|
||||||
self.generate_config_and_remove_lines_containing("server_name")
|
self.generate_config_and_remove_lines_containing("server_name")
|
||||||
with self.assertRaises(ConfigError):
|
with self.assertRaises(ConfigError):
|
||||||
HomeServerConfig.load_config("", ["-c", self.config_file])
|
HomeServerConfig.load_config("", ["-c", self.config_file])
|
||||||
with self.assertRaises(ConfigError):
|
with self.assertRaises(ConfigError):
|
||||||
HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
|
HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
|
||||||
|
|
||||||
def test_generates_and_loads_macaroon_secret_key(self):
|
def test_generates_and_loads_macaroon_secret_key(self) -> None:
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
|
|
||||||
with open(self.config_file) as f:
|
with open(self.config_file) as f:
|
||||||
|
@ -58,7 +58,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
|
||||||
"was: %r" % (config2.key.macaroon_secret_key,)
|
"was: %r" % (config2.key.macaroon_secret_key,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_load_succeeds_if_macaroon_secret_key_missing(self):
|
def test_load_succeeds_if_macaroon_secret_key_missing(self) -> None:
|
||||||
self.generate_config_and_remove_lines_containing("macaroon")
|
self.generate_config_and_remove_lines_containing("macaroon")
|
||||||
config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
|
config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
|
||||||
config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
|
config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
|
||||||
|
@ -73,7 +73,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
|
||||||
config1.key.macaroon_secret_key, config3.key.macaroon_secret_key
|
config1.key.macaroon_secret_key, config3.key.macaroon_secret_key
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_disable_registration(self):
|
def test_disable_registration(self) -> None:
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
self.add_lines_to_config(
|
self.add_lines_to_config(
|
||||||
["enable_registration: true", "disable_registration: true"]
|
["enable_registration: true", "disable_registration: true"]
|
||||||
|
@ -93,7 +93,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
|
||||||
assert config3 is not None
|
assert config3 is not None
|
||||||
self.assertTrue(config3.registration.enable_registration)
|
self.assertTrue(config3.registration.enable_registration)
|
||||||
|
|
||||||
def test_stats_enabled(self):
|
def test_stats_enabled(self) -> None:
|
||||||
self.generate_config_and_remove_lines_containing("enable_metrics")
|
self.generate_config_and_remove_lines_containing("enable_metrics")
|
||||||
self.add_lines_to_config(["enable_metrics: true"])
|
self.add_lines_to_config(["enable_metrics: true"])
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
|
||||||
config = HomeServerConfig.load_config("", ["-c", self.config_file])
|
config = HomeServerConfig.load_config("", ["-c", self.config_file])
|
||||||
self.assertFalse(config.metrics.metrics_flags.known_servers)
|
self.assertFalse(config.metrics.metrics_flags.known_servers)
|
||||||
|
|
||||||
def test_depreciated_identity_server_flag_throws_error(self):
|
def test_depreciated_identity_server_flag_throws_error(self) -> None:
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
# Needed to ensure that actual key/value pair added below don't end up on a line with a comment
|
# Needed to ensure that actual key/value pair added below don't end up on a line with a comment
|
||||||
self.add_lines_to_config([" "])
|
self.add_lines_to_config([" "])
|
||||||
|
|
|
@ -18,7 +18,7 @@ from tests.utils import default_config
|
||||||
|
|
||||||
|
|
||||||
class RatelimitConfigTestCase(TestCase):
|
class RatelimitConfigTestCase(TestCase):
|
||||||
def test_parse_rc_federation(self):
|
def test_parse_rc_federation(self) -> None:
|
||||||
config_dict = default_config("test")
|
config_dict = default_config("test")
|
||||||
config_dict["rc_federation"] = {
|
config_dict["rc_federation"] = {
|
||||||
"window_size": 20000,
|
"window_size": 20000,
|
||||||
|
|
|
@ -21,7 +21,7 @@ from tests.utils import default_config
|
||||||
|
|
||||||
|
|
||||||
class RegistrationConfigTestCase(ConfigFileTestCase):
|
class RegistrationConfigTestCase(ConfigFileTestCase):
|
||||||
def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
|
def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self) -> None:
|
||||||
"""
|
"""
|
||||||
session_lifetime should logically be larger than, or at least as large as,
|
session_lifetime should logically be larger than, or at least as large as,
|
||||||
all the different token lifetimes.
|
all the different token lifetimes.
|
||||||
|
@ -91,7 +91,7 @@ class RegistrationConfigTestCase(ConfigFileTestCase):
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_refuse_to_start_if_open_registration_and_no_verification(self):
|
def test_refuse_to_start_if_open_registration_and_no_verification(self) -> None:
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
self.add_lines_to_config(
|
self.add_lines_to_config(
|
||||||
[
|
[
|
||||||
|
|
|
@ -20,7 +20,7 @@ from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class RoomDirectoryConfigTestCase(unittest.TestCase):
|
class RoomDirectoryConfigTestCase(unittest.TestCase):
|
||||||
def test_alias_creation_acl(self):
|
def test_alias_creation_acl(self) -> None:
|
||||||
config = yaml.safe_load(
|
config = yaml.safe_load(
|
||||||
"""
|
"""
|
||||||
alias_creation_rules:
|
alias_creation_rules:
|
||||||
|
@ -78,7 +78,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_room_publish_acl(self):
|
def test_room_publish_acl(self) -> None:
|
||||||
config = yaml.safe_load(
|
config = yaml.safe_load(
|
||||||
"""
|
"""
|
||||||
alias_creation_rules: []
|
alias_creation_rules: []
|
||||||
|
|
|
@ -21,7 +21,7 @@ from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class ServerConfigTestCase(unittest.TestCase):
|
class ServerConfigTestCase(unittest.TestCase):
|
||||||
def test_is_threepid_reserved(self):
|
def test_is_threepid_reserved(self) -> None:
|
||||||
user1 = {"medium": "email", "address": "user1@example.com"}
|
user1 = {"medium": "email", "address": "user1@example.com"}
|
||||||
user2 = {"medium": "email", "address": "user2@example.com"}
|
user2 = {"medium": "email", "address": "user2@example.com"}
|
||||||
user3 = {"medium": "email", "address": "user3@example.com"}
|
user3 = {"medium": "email", "address": "user3@example.com"}
|
||||||
|
@ -32,7 +32,7 @@ class ServerConfigTestCase(unittest.TestCase):
|
||||||
self.assertFalse(is_threepid_reserved(config, user3))
|
self.assertFalse(is_threepid_reserved(config, user3))
|
||||||
self.assertFalse(is_threepid_reserved(config, user1_msisdn))
|
self.assertFalse(is_threepid_reserved(config, user1_msisdn))
|
||||||
|
|
||||||
def test_unsecure_listener_no_listeners_open_private_ports_false(self):
|
def test_unsecure_listener_no_listeners_open_private_ports_false(self) -> None:
|
||||||
conf = yaml.safe_load(
|
conf = yaml.safe_load(
|
||||||
ServerConfig().generate_config_section(
|
ServerConfig().generate_config_section(
|
||||||
"CONFDIR", "/data_dir_path", "che.org", False, None
|
"CONFDIR", "/data_dir_path", "che.org", False, None
|
||||||
|
@ -52,7 +52,7 @@ class ServerConfigTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(conf["listeners"], expected_listeners)
|
self.assertEqual(conf["listeners"], expected_listeners)
|
||||||
|
|
||||||
def test_unsecure_listener_no_listeners_open_private_ports_true(self):
|
def test_unsecure_listener_no_listeners_open_private_ports_true(self) -> None:
|
||||||
conf = yaml.safe_load(
|
conf = yaml.safe_load(
|
||||||
ServerConfig().generate_config_section(
|
ServerConfig().generate_config_section(
|
||||||
"CONFDIR", "/data_dir_path", "che.org", True, None
|
"CONFDIR", "/data_dir_path", "che.org", True, None
|
||||||
|
@ -71,7 +71,7 @@ class ServerConfigTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(conf["listeners"], expected_listeners)
|
self.assertEqual(conf["listeners"], expected_listeners)
|
||||||
|
|
||||||
def test_listeners_set_correctly_open_private_ports_false(self):
|
def test_listeners_set_correctly_open_private_ports_false(self) -> None:
|
||||||
listeners = [
|
listeners = [
|
||||||
{
|
{
|
||||||
"port": 8448,
|
"port": 8448,
|
||||||
|
@ -95,7 +95,7 @@ class ServerConfigTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(conf["listeners"], listeners)
|
self.assertEqual(conf["listeners"], listeners)
|
||||||
|
|
||||||
def test_listeners_set_correctly_open_private_ports_true(self):
|
def test_listeners_set_correctly_open_private_ports_true(self) -> None:
|
||||||
listeners = [
|
listeners = [
|
||||||
{
|
{
|
||||||
"port": 8448,
|
"port": 8448,
|
||||||
|
@ -131,14 +131,14 @@ class ServerConfigTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class GenerateIpSetTestCase(unittest.TestCase):
|
class GenerateIpSetTestCase(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self) -> None:
|
||||||
ip_set = generate_ip_set(())
|
ip_set = generate_ip_set(())
|
||||||
self.assertFalse(ip_set)
|
self.assertFalse(ip_set)
|
||||||
|
|
||||||
ip_set = generate_ip_set((), ())
|
ip_set = generate_ip_set((), ())
|
||||||
self.assertFalse(ip_set)
|
self.assertFalse(ip_set)
|
||||||
|
|
||||||
def test_generate(self):
|
def test_generate(self) -> None:
|
||||||
"""Check adding IPv4 and IPv6 addresses."""
|
"""Check adding IPv4 and IPv6 addresses."""
|
||||||
# IPv4 address
|
# IPv4 address
|
||||||
ip_set = generate_ip_set(("1.2.3.4",))
|
ip_set = generate_ip_set(("1.2.3.4",))
|
||||||
|
@ -160,7 +160,7 @@ class GenerateIpSetTestCase(unittest.TestCase):
|
||||||
ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
|
ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
|
||||||
self.assertEqual(len(ip_set.iter_cidrs()), 4)
|
self.assertEqual(len(ip_set.iter_cidrs()), 4)
|
||||||
|
|
||||||
def test_extra(self):
|
def test_extra(self) -> None:
|
||||||
"""Extra IP addresses are treated the same."""
|
"""Extra IP addresses are treated the same."""
|
||||||
ip_set = generate_ip_set((), ("1.2.3.4",))
|
ip_set = generate_ip_set((), ("1.2.3.4",))
|
||||||
self.assertEqual(len(ip_set.iter_cidrs()), 4)
|
self.assertEqual(len(ip_set.iter_cidrs()), 4)
|
||||||
|
@ -172,7 +172,7 @@ class GenerateIpSetTestCase(unittest.TestCase):
|
||||||
ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
|
ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
|
||||||
self.assertEqual(len(ip_set.iter_cidrs()), 4)
|
self.assertEqual(len(ip_set.iter_cidrs()), 4)
|
||||||
|
|
||||||
def test_bad_value(self):
|
def test_bad_value(self) -> None:
|
||||||
"""An error should be raised if a bad value is passed in."""
|
"""An error should be raised if a bad value is passed in."""
|
||||||
with self.assertRaises(ConfigError):
|
with self.assertRaises(ConfigError):
|
||||||
generate_ip_set(("not-an-ip",))
|
generate_ip_set(("not-an-ip",))
|
||||||
|
|
|
@ -13,13 +13,20 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import idna
|
import idna
|
||||||
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
from synapse.config._base import Config, RootConfig
|
from synapse.config._base import Config, RootConfig
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.tls import ConfigError, TlsConfig
|
from synapse.config.tls import ConfigError, TlsConfig
|
||||||
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
from synapse.crypto.context_factory import (
|
||||||
|
FederationPolicyForHTTPS,
|
||||||
|
SSLClientConnectionCreator,
|
||||||
|
)
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
from tests.unittest import TestCase
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
@ -27,7 +34,7 @@ from tests.unittest import TestCase
|
||||||
class FakeServer(Config):
|
class FakeServer(Config):
|
||||||
section = "server"
|
section = "server"
|
||||||
|
|
||||||
def has_tls_listener(self):
|
def has_tls_listener(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,21 +43,21 @@ class TestConfig(RootConfig):
|
||||||
|
|
||||||
|
|
||||||
class TLSConfigTests(TestCase):
|
class TLSConfigTests(TestCase):
|
||||||
def test_tls_client_minimum_default(self):
|
def test_tls_client_minimum_default(self) -> None:
|
||||||
"""
|
"""
|
||||||
The default client TLS version is 1.0.
|
The default client TLS version is 1.0.
|
||||||
"""
|
"""
|
||||||
config = {}
|
config: JsonDict = {}
|
||||||
t = TestConfig()
|
t = TestConfig()
|
||||||
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
|
|
||||||
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
|
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
|
||||||
|
|
||||||
def test_tls_client_minimum_set(self):
|
def test_tls_client_minimum_set(self) -> None:
|
||||||
"""
|
"""
|
||||||
The default client TLS version can be set to 1.0, 1.1, and 1.2.
|
The default client TLS version can be set to 1.0, 1.1, and 1.2.
|
||||||
"""
|
"""
|
||||||
config = {"federation_client_minimum_tls_version": 1}
|
config: JsonDict = {"federation_client_minimum_tls_version": 1}
|
||||||
t = TestConfig()
|
t = TestConfig()
|
||||||
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
|
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
|
||||||
|
@ -76,7 +83,7 @@ class TLSConfigTests(TestCase):
|
||||||
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
|
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
|
||||||
|
|
||||||
def test_tls_client_minimum_1_point_3_missing(self):
|
def test_tls_client_minimum_1_point_3_missing(self) -> None:
|
||||||
"""
|
"""
|
||||||
If TLS 1.3 support is missing and it's configured, it will raise a
|
If TLS 1.3 support is missing and it's configured, it will raise a
|
||||||
ConfigError.
|
ConfigError.
|
||||||
|
@ -88,7 +95,7 @@ class TLSConfigTests(TestCase):
|
||||||
self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3)
|
self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3)
|
||||||
assert not hasattr(SSL, "OP_NO_TLSv1_3")
|
assert not hasattr(SSL, "OP_NO_TLSv1_3")
|
||||||
|
|
||||||
config = {"federation_client_minimum_tls_version": 1.3}
|
config: JsonDict = {"federation_client_minimum_tls_version": 1.3}
|
||||||
t = TestConfig()
|
t = TestConfig()
|
||||||
with self.assertRaises(ConfigError) as e:
|
with self.assertRaises(ConfigError) as e:
|
||||||
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
|
@ -100,7 +107,7 @@ class TLSConfigTests(TestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_tls_client_minimum_1_point_3_exists(self):
|
def test_tls_client_minimum_1_point_3_exists(self) -> None:
|
||||||
"""
|
"""
|
||||||
If TLS 1.3 support exists and it's configured, it will be settable.
|
If TLS 1.3 support exists and it's configured, it will be settable.
|
||||||
"""
|
"""
|
||||||
|
@ -110,20 +117,20 @@ class TLSConfigTests(TestCase):
|
||||||
self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3"))
|
self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3"))
|
||||||
assert hasattr(SSL, "OP_NO_TLSv1_3")
|
assert hasattr(SSL, "OP_NO_TLSv1_3")
|
||||||
|
|
||||||
config = {"federation_client_minimum_tls_version": 1.3}
|
config: JsonDict = {"federation_client_minimum_tls_version": 1.3}
|
||||||
t = TestConfig()
|
t = TestConfig()
|
||||||
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3")
|
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3")
|
||||||
|
|
||||||
def test_tls_client_minimum_set_passed_through_1_2(self):
|
def test_tls_client_minimum_set_passed_through_1_2(self) -> None:
|
||||||
"""
|
"""
|
||||||
The configured TLS version is correctly configured by the ContextFactory.
|
The configured TLS version is correctly configured by the ContextFactory.
|
||||||
"""
|
"""
|
||||||
config = {"federation_client_minimum_tls_version": 1.2}
|
config: JsonDict = {"federation_client_minimum_tls_version": 1.2}
|
||||||
t = TestConfig()
|
t = TestConfig()
|
||||||
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
|
|
||||||
cf = FederationPolicyForHTTPS(t)
|
cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
|
||||||
options = _get_ssl_context_options(cf._verify_ssl_context)
|
options = _get_ssl_context_options(cf._verify_ssl_context)
|
||||||
|
|
||||||
# The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
|
# The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
|
||||||
|
@ -131,15 +138,15 @@ class TLSConfigTests(TestCase):
|
||||||
self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
|
self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
|
||||||
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
|
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
|
||||||
|
|
||||||
def test_tls_client_minimum_set_passed_through_1_0(self):
|
def test_tls_client_minimum_set_passed_through_1_0(self) -> None:
|
||||||
"""
|
"""
|
||||||
The configured TLS version is correctly configured by the ContextFactory.
|
The configured TLS version is correctly configured by the ContextFactory.
|
||||||
"""
|
"""
|
||||||
config = {"federation_client_minimum_tls_version": 1}
|
config: JsonDict = {"federation_client_minimum_tls_version": 1}
|
||||||
t = TestConfig()
|
t = TestConfig()
|
||||||
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
|
|
||||||
cf = FederationPolicyForHTTPS(t)
|
cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
|
||||||
options = _get_ssl_context_options(cf._verify_ssl_context)
|
options = _get_ssl_context_options(cf._verify_ssl_context)
|
||||||
|
|
||||||
# The context has not had any of the NO_TLS set.
|
# The context has not had any of the NO_TLS set.
|
||||||
|
@ -147,11 +154,11 @@ class TLSConfigTests(TestCase):
|
||||||
self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
|
self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
|
||||||
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
|
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
|
||||||
|
|
||||||
def test_whitelist_idna_failure(self):
|
def test_whitelist_idna_failure(self) -> None:
|
||||||
"""
|
"""
|
||||||
The federation certificate whitelist will not allow IDNA domain names.
|
The federation certificate whitelist will not allow IDNA domain names.
|
||||||
"""
|
"""
|
||||||
config = {
|
config: JsonDict = {
|
||||||
"federation_certificate_verification_whitelist": [
|
"federation_certificate_verification_whitelist": [
|
||||||
"example.com",
|
"example.com",
|
||||||
"*.ドメイン.テスト",
|
"*.ドメイン.テスト",
|
||||||
|
@ -163,11 +170,11 @@ class TLSConfigTests(TestCase):
|
||||||
)
|
)
|
||||||
self.assertIn("IDNA domain names", str(e))
|
self.assertIn("IDNA domain names", str(e))
|
||||||
|
|
||||||
def test_whitelist_idna_result(self):
|
def test_whitelist_idna_result(self) -> None:
|
||||||
"""
|
"""
|
||||||
The federation certificate whitelist will match on IDNA encoded names.
|
The federation certificate whitelist will match on IDNA encoded names.
|
||||||
"""
|
"""
|
||||||
config = {
|
config: JsonDict = {
|
||||||
"federation_certificate_verification_whitelist": [
|
"federation_certificate_verification_whitelist": [
|
||||||
"example.com",
|
"example.com",
|
||||||
"*.xn--eckwd4c7c.xn--zckzah",
|
"*.xn--eckwd4c7c.xn--zckzah",
|
||||||
|
@ -176,14 +183,16 @@ class TLSConfigTests(TestCase):
|
||||||
t = TestConfig()
|
t = TestConfig()
|
||||||
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
t.tls.read_config(config, config_dir_path="", data_dir_path="")
|
||||||
|
|
||||||
cf = FederationPolicyForHTTPS(t)
|
cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
|
||||||
|
|
||||||
# Not in the whitelist
|
# Not in the whitelist
|
||||||
opts = cf.get_options(b"notexample.com")
|
opts = cf.get_options(b"notexample.com")
|
||||||
|
assert isinstance(opts, SSLClientConnectionCreator)
|
||||||
self.assertTrue(opts._verifier._verify_certs)
|
self.assertTrue(opts._verifier._verify_certs)
|
||||||
|
|
||||||
# Caught by the wildcard
|
# Caught by the wildcard
|
||||||
opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
|
opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
|
||||||
|
assert isinstance(opts, SSLClientConnectionCreator)
|
||||||
self.assertFalse(opts._verifier._verify_certs)
|
self.assertFalse(opts._verifier._verify_certs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -191,4 +200,4 @@ def _get_ssl_context_options(ssl_context: SSL.Context) -> int:
|
||||||
"""get the options bits from an openssl context object"""
|
"""get the options bits from an openssl context object"""
|
||||||
# the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
|
# the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
|
||||||
# use the low-level interface
|
# use the low-level interface
|
||||||
return SSL._lib.SSL_CTX_get_options(ssl_context._context)
|
return SSL._lib.SSL_CTX_get_options(ssl_context._context) # type: ignore[attr-defined]
|
||||||
|
|
|
@ -21,7 +21,7 @@ from tests.unittest import TestCase
|
||||||
class ValidateConfigTestCase(TestCase):
|
class ValidateConfigTestCase(TestCase):
|
||||||
"""Test cases for synapse.config._util.validate_config"""
|
"""Test cases for synapse.config._util.validate_config"""
|
||||||
|
|
||||||
def test_bad_object_in_array(self):
|
def test_bad_object_in_array(self) -> None:
|
||||||
"""malformed objects within an array should be validated correctly"""
|
"""malformed objects within an array should be validated correctly"""
|
||||||
|
|
||||||
# consider a structure:
|
# consider a structure:
|
||||||
|
|
|
@ -17,19 +17,20 @@ import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
|
||||||
|
|
||||||
class ConfigFileTestCase(unittest.TestCase):
|
class ConfigFileTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
self.dir = tempfile.mkdtemp()
|
self.dir = tempfile.mkdtemp()
|
||||||
self.config_file = os.path.join(self.dir, "homeserver.yaml")
|
self.config_file = os.path.join(self.dir, "homeserver.yaml")
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self) -> None:
|
||||||
shutil.rmtree(self.dir)
|
shutil.rmtree(self.dir)
|
||||||
|
|
||||||
def generate_config(self):
|
def generate_config(self) -> None:
|
||||||
with redirect_stdout(StringIO()):
|
with redirect_stdout(StringIO()):
|
||||||
HomeServerConfig.load_or_generate_config(
|
HomeServerConfig.load_or_generate_config(
|
||||||
"",
|
"",
|
||||||
|
@ -43,7 +44,7 @@ class ConfigFileTestCase(unittest.TestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_config_and_remove_lines_containing(self, needle):
|
def generate_config_and_remove_lines_containing(self, needle: str) -> None:
|
||||||
self.generate_config()
|
self.generate_config()
|
||||||
|
|
||||||
with open(self.config_file) as f:
|
with open(self.config_file) as f:
|
||||||
|
@ -52,7 +53,7 @@ class ConfigFileTestCase(unittest.TestCase):
|
||||||
with open(self.config_file, "w") as f:
|
with open(self.config_file, "w") as f:
|
||||||
f.write("".join(contents))
|
f.write("".join(contents))
|
||||||
|
|
||||||
def add_lines_to_config(self, lines):
|
def add_lines_to_config(self, lines: List[str]) -> None:
|
||||||
with open(self.config_file, "a") as f:
|
with open(self.config_file, "a") as f:
|
||||||
for line in lines:
|
for line in lines:
|
||||||
f.write(line + "\n")
|
f.write(line + "\n")
|
||||||
|
|
Loading…
Reference in New Issue