Add type hints for most `HomeServer` parameters (#11095)

This commit is contained in:
Sean Quah 2021-10-22 18:15:41 +01:00 committed by GitHub
parent b9ce53e878
commit 2b82ec425f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
58 changed files with 342 additions and 143 deletions

1
changelog.d/11095.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to most `HomeServer` parameters.

View File

@ -294,7 +294,7 @@ def listen_ssl(
return r return r
def refresh_certificate(hs): def refresh_certificate(hs: "HomeServer"):
""" """
Refresh the TLS certificates that Synapse is using by re-reading them from Refresh the TLS certificates that Synapse is using by re-reading them from
disk and updating the TLS context factories to use them. disk and updating the TLS context factories to use them.
@ -419,11 +419,11 @@ async def start(hs: "HomeServer"):
atexit.register(gc.freeze) atexit.register(gc.freeze)
def setup_sentry(hs): def setup_sentry(hs: "HomeServer"):
"""Enable sentry integration, if enabled in configuration """Enable sentry integration, if enabled in configuration
Args: Args:
hs (synapse.server.HomeServer) hs
""" """
if not hs.config.metrics.sentry_enabled: if not hs.config.metrics.sentry_enabled:
@ -449,7 +449,7 @@ def setup_sentry(hs):
scope.set_tag("worker_name", name) scope.set_tag("worker_name", name)
def setup_sdnotify(hs): def setup_sdnotify(hs: "HomeServer"):
"""Adds process state hooks to tell systemd what we are up to.""" """Adds process state hooks to tell systemd what we are up to."""
# Tell systemd our state, if we're using it. This will silently fail if # Tell systemd our state, if we're using it. This will silently fail if

View File

@ -68,11 +68,11 @@ class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdSlavedStore DATASTORE_CLASS = AdminCmdSlavedStore
async def export_data_command(hs, args): async def export_data_command(hs: HomeServer, args):
"""Export data for a user. """Export data for a user.
Args: Args:
hs (HomeServer) hs
args (argparse.Namespace) args (argparse.Namespace)
""" """

View File

@ -131,10 +131,10 @@ class KeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs: HomeServer):
""" """
Args: Args:
hs (synapse.server.HomeServer): server hs: server
""" """
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()

View File

@ -412,7 +412,7 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
e = e.__cause__ e = e.__cause__
def run(hs): def run(hs: HomeServer):
PROFILE_SYNAPSE = False PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE: if PROFILE_SYNAPSE:

View File

@ -15,11 +15,15 @@ import logging
import math import math
import resource import resource
import sys import sys
from typing import TYPE_CHECKING
from prometheus_client import Gauge from prometheus_client import Gauge
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
# Contains the list of processes we will be monitoring # Contains the list of processes we will be monitoring
@ -41,7 +45,7 @@ registered_reserved_users_mau_gauge = Gauge(
@wrap_as_background_process("phone_stats_home") @wrap_as_background_process("phone_stats_home")
async def phone_stats_home(hs, stats, stats_process=_stats_process): async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting") logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time()) now = int(hs.get_clock().time())
uptime = int(now - hs.start_time) uptime = int(now - hs.start_time)
@ -142,7 +146,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.warning("Error reporting stats: %s", e) logger.warning("Error reporting stats: %s", e)
def start_phone_stats_home(hs): def start_phone_stats_home(hs: "HomeServer"):
""" """
Start the background tasks which report phone home stats. Start the background tasks which report phone home stats.
""" """

View File

@ -27,6 +27,7 @@ from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -84,7 +85,7 @@ class ApplicationServiceApi(SimpleHttpClient):
pushing. pushing.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()

View File

@ -18,6 +18,7 @@ import os
import sys import sys
import threading import threading
from string import Template from string import Template
from typing import TYPE_CHECKING
import yaml import yaml
from zope.interface import implementer from zope.interface import implementer
@ -38,6 +39,9 @@ from synapse.util.versionstring import get_version_string
from ._base import Config, ConfigError from ._base import Config, ConfigError
if TYPE_CHECKING:
from synapse.server import HomeServer
DEFAULT_LOG_CONFIG = Template( DEFAULT_LOG_CONFIG = Template(
"""\ """\
# Log configuration for Synapse. # Log configuration for Synapse.
@ -306,7 +310,10 @@ def _reload_logging_config(log_config_path):
def setup_logging( def setup_logging(
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner hs: "HomeServer",
config,
use_worker_options=False,
logBeginner: LogBeginner = globalLogBeginner,
) -> None: ) -> None:
""" """
Set up the logging subsystem. Set up the logging subsystem.

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import TYPE_CHECKING
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
@ -25,11 +26,15 @@ from synapse.events.utils import prune_event, validate_canonicaljson
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationBase: class FederationBase:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.server_name = hs.hostname self.server_name = hs.hostname

View File

@ -467,7 +467,7 @@ class FederationServer(FederationBase):
async def on_room_state_request( async def on_room_state_request(
self, origin: str, room_id: str, event_id: Optional[str] self, origin: str, room_id: str, event_id: Optional[str]
) -> Tuple[int, Dict[str, Any]]: ) -> Tuple[int, JsonDict]:
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
@ -481,7 +481,7 @@ class FederationServer(FederationBase):
# - but that's non-trivial to get right, and anyway somewhat defeats # - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer. # the point of the linearizer.
with (await self._server_linearizer.queue((origin, room_id))): with (await self._server_linearizer.queue((origin, room_id))):
resp = dict( resp: JsonDict = dict(
await self._state_resp_cache.wrap( await self._state_resp_cache.wrap(
(room_id, event_id), (room_id, event_id),
self._on_context_state_request_compute, self._on_context_state_request_compute,
@ -1061,11 +1061,12 @@ class FederationServer(FederationBase):
origin, event = next origin, event = next
lock = await self.store.try_acquire_lock( new_lock = await self.store.try_acquire_lock(
_INBOUND_EVENT_HANDLING_LOCK_NAME, room_id _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
) )
if not lock: if not new_lock:
return return
lock = new_lock
def __str__(self) -> str: def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name

View File

@ -21,6 +21,7 @@ import typing
import urllib.parse import urllib.parse
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import ( from typing import (
TYPE_CHECKING,
Callable, Callable,
Dict, Dict,
Generic, Generic,
@ -73,6 +74,9 @@ from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter( outgoing_requests_counter = Counter(
@ -319,7 +323,7 @@ class MatrixFederationHttpClient:
requests. requests.
""" """
def __init__(self, hs, tls_client_options_factory): def __init__(self, hs: "HomeServer", tls_client_options_factory):
self.hs = hs self.hs = hs
self.signing_key = hs.signing_key self.signing_key = hs.signing_key
self.server_name = hs.hostname self.server_name = hs.hostname
@ -711,7 +715,7 @@ class MatrixFederationHttpClient:
Returns: Returns:
A list of headers to be added as "Authorization:" headers A list of headers to be added as "Authorization:" headers
""" """
request = { request: JsonDict = {
"method": method.decode("ascii"), "method": method.decode("ascii"),
"uri": url_bytes.decode("ascii"), "uri": url_bytes.decode("ascii"),
"origin": self.server_name, "origin": self.server_name,

View File

@ -22,6 +22,7 @@ import urllib
from http import HTTPStatus from http import HTTPStatus
from inspect import isawaitable from inspect import isawaitable
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
@ -61,6 +62,9 @@ from synapse.util import json_encoder
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HTML_ERROR_TEMPLATE = """<!DOCTYPE html> HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
@ -343,6 +347,11 @@ class DirectServeJsonResource(_AsyncResource):
return_json_error(f, request) return_json_error(f, request)
_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
class JsonResource(DirectServeJsonResource): class JsonResource(DirectServeJsonResource):
"""This implements the HttpServer interface and provides JSON support for """This implements the HttpServer interface and provides JSON support for
Resources. Resources.
@ -359,14 +368,10 @@ class JsonResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
_PathEntry = collections.namedtuple( def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
def __init__(self, hs, canonical_json=True, extract_context=False):
super().__init__(canonical_json, extract_context) super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.path_regexs = {} self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs self.hs = hs
def register_paths(self, method, path_patterns, callback, servlet_classname): def register_paths(self, method, path_patterns, callback, servlet_classname):
@ -391,7 +396,7 @@ class JsonResource(DirectServeJsonResource):
for path_pattern in path_patterns: for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern) logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname) _PathEntry(path_pattern, callback, servlet_classname)
) )
def _get_handler_for_request( def _get_handler_for_request(

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.replication.http import ( from synapse.replication.http import (
account_data, account_data,
@ -26,16 +28,19 @@ from synapse.replication.http import (
streams, streams,
) )
if TYPE_CHECKING:
from synapse.server import HomeServer
REPLICATION_PREFIX = "/_synapse/replication" REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource): class ReplicationRestResource(JsonResource):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
# We enable extracting jaeger contexts here as these are internal APIs. # We enable extracting jaeger contexts here as these are internal APIs.
super().__init__(hs, canonical_json=False, extract_context=True) super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs) self.register_servlets(hs)
def register_servlets(self, hs): def register_servlets(self, hs: "HomeServer"):
send_event.register_servlets(hs, self) send_event.register_servlets(hs, self)
federation.register_servlets(hs, self) federation.register_servlets(hs, self)
presence.register_servlets(hs, self) presence.register_servlets(hs, self)

View File

@ -17,7 +17,7 @@ import logging
import re import re
import urllib import urllib
from inspect import signature from inspect import signature
from typing import TYPE_CHECKING, Dict, List, Tuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
from prometheus_client import Counter, Gauge from prometheus_client import Counter, Gauge
@ -156,7 +156,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
pass pass
@classmethod @classmethod
def make_client(cls, hs): def make_client(cls, hs: "HomeServer"):
"""Create a client that makes requests. """Create a client that makes requests.
Returns a callable that accepts the same parameters as Returns a callable that accepts the same parameters as
@ -208,7 +208,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
url_args.append(txn_id) url_args.append(txn_id)
if cls.METHOD == "POST": if cls.METHOD == "POST":
request_func = client.post_json_get_json request_func: Callable[
..., Awaitable[Any]
] = client.post_json_get_json
elif cls.METHOD == "PUT": elif cls.METHOD == "PUT":
request_func = client.put_json request_func = client.put_json
elif cls.METHOD == "GET": elif cls.METHOD == "GET":

View File

@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,7 +41,7 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id", "account_data_type") PATH_ARGS = ("user_id", "account_data_type")
CACHE = False CACHE = False
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
@ -78,7 +82,7 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id", "room_id", "account_data_type") PATH_ARGS = ("user_id", "room_id", "account_data_type")
CACHE = False CACHE = False
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
@ -119,7 +123,7 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id", "room_id", "tag") PATH_ARGS = ("user_id", "room_id", "tag")
CACHE = False CACHE = False
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
@ -162,7 +166,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
) )
CACHE = False CACHE = False
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
@ -183,7 +187,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
return 200, {"max_stream_id": max_stream_id} return 200, {"max_stream_id": max_stream_id}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
ReplicationUserAccountDataRestServlet(hs).register(http_server) ReplicationUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server) ReplicationRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server)

View File

@ -13,9 +13,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,7 +55,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id",) PATH_ARGS = ("user_id",)
CACHE = False CACHE = False
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.device_list_updater = hs.get_device_handler().device_list_updater self.device_list_updater = hs.get_device_handler().device_list_updater
@ -68,5 +72,5 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
return 200, user_devices return 200, user_devices
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
ReplicationUserDevicesResyncRestServlet(hs).register(http_server) ReplicationUserDevicesResyncRestServlet(hs).register(http_server)

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
@ -21,6 +22,9 @@ from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +60,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
NAME = "fed_send_events" NAME = "fed_send_events"
PATH_ARGS = () PATH_ARGS = ()
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -151,7 +155,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
NAME = "fed_send_edu" NAME = "fed_send_edu"
PATH_ARGS = ("edu_type",) PATH_ARGS = ("edu_type",)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -194,7 +198,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
# This is a query, so let's not bother caching # This is a query, so let's not bother caching
CACHE = False CACHE = False
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -238,7 +242,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
NAME = "fed_cleanup_room" NAME = "fed_cleanup_room"
PATH_ARGS = ("room_id",) PATH_ARGS = ("room_id",)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -273,7 +277,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
NAME = "store_room_on_outlier_membership" NAME = "store_room_on_outlier_membership"
PATH_ARGS = ("room_id",) PATH_ARGS = ("room_id",)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -289,7 +293,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEventsRestServlet(hs).register(http_server)
ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server)

View File

@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,7 +34,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
NAME = "device_check_registered" NAME = "device_check_registered"
PATH_ARGS = ("user_id",) PATH_ARGS = ("user_id",)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@ -82,5 +86,5 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
return 200, res return 200, res
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
RegisterDeviceReplicationServlet(hs).register(http_server) RegisterDeviceReplicationServlet(hs).register(http_server)

View File

@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
NAME = "remote_join" NAME = "remote_join"
PATH_ARGS = ("room_id", "user_id") PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.federation_handler = hs.get_federation_handler() self.federation_handler = hs.get_federation_handler()
@ -320,7 +320,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
PATH_ARGS = ("room_id", "user_id", "change") PATH_ARGS = ("room_id", "user_id", "change")
CACHE = False # No point caching as should return instantly. CACHE = False # No point caching as should return instantly.
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.registeration_handler = hs.get_registration_handler() self.registeration_handler = hs.get_registration_handler()
@ -360,7 +360,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)

View File

@ -117,6 +117,6 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
) )
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
ReplicationBumpPresenceActiveTime(hs).register(http_server) ReplicationBumpPresenceActiveTime(hs).register(http_server)
ReplicationPresenceSetState(hs).register(http_server) ReplicationPresenceSetState(hs).register(http_server)

View File

@ -67,5 +67,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
ReplicationRemovePusherRestServlet(hs).register(http_server) ReplicationRemovePusherRestServlet(hs).register(http_server)

View File

@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,7 +30,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
NAME = "register_user" NAME = "register_user"
PATH_ARGS = ("user_id",) PATH_ARGS = ("user_id",)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@ -100,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
NAME = "post_register" NAME = "post_register"
PATH_ARGS = ("user_id",) PATH_ARGS = ("user_id",)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@ -130,6 +134,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
ReplicationRegisterServlet(hs).register(http_server) ReplicationRegisterServlet(hs).register(http_server)
ReplicationPostRegisterActionsServlet(hs).register(http_server) ReplicationPostRegisterActionsServlet(hs).register(http_server)

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
@ -22,6 +23,9 @@ from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID from synapse.types import Requester, UserID
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,7 +61,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
NAME = "send_event" NAME = "send_event"
PATH_ARGS = ("event_id",) PATH_ARGS = ("event_id",)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -135,5 +139,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
) )
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
ReplicationSendEventRestServlet(hs).register(http_server) ReplicationSendEventRestServlet(hs).register(http_server)

View File

@ -13,11 +13,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,7 +50,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
PATH_ARGS = ("stream_name",) PATH_ARGS = ("stream_name",)
METHOD = "GET" METHOD = "GET"
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
@ -74,5 +78,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
) )
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
ReplicationGetStreamUpdates(hs).register(http_server) ReplicationGetStreamUpdates(hs).register(http_server)

View File

@ -13,18 +13,21 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Optional
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseSlavedStore(CacheInvalidationWorkerStore): class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen: Optional[ self._cache_id_gen: Optional[

View File

@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
if TYPE_CHECKING:
from synapse.server import HomeServer
class SlavedClientIpStore(BaseSlavedStore): class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.client_ip_last_seen: LruCache[tuple, int] = LruCache( self.client_ip_last_seen: LruCache[tuple, int] = LruCache(

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
@ -20,9 +22,12 @@ from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.hs = hs self.hs = hs

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
@ -30,6 +31,9 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -54,7 +58,7 @@ class SlavedEventStore(
RelationsWorkerStore, RelationsWorkerStore,
BaseSlavedStore, BaseSlavedStore,
): ):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()

View File

@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.filtering import FilteringStore from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
if TYPE_CHECKING:
from synapse.server import HomeServer
class SlavedFilteringStore(BaseSlavedStore): class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired # Filters are immutable so this cache doesn't need to be expired

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream from synapse.replication.tcp.streams import GroupServerStream
@ -19,9 +21,12 @@ from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.hs = hs self.hs = hs

View File

@ -21,6 +21,8 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.util import json_decoder, json_encoder from synapse.util import json_decoder, json_encoder
if TYPE_CHECKING: if TYPE_CHECKING:
from txredisapi import RedisProtocol
from synapse.server import HomeServer from synapse.server import HomeServer
set_counter = Counter( set_counter = Counter(
@ -59,7 +61,12 @@ class ExternalCache:
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._redis_connection = hs.get_outbound_redis_connection() if hs.config.redis.redis_enabled:
self._redis_connection: Optional[
"RedisProtocol"
] = hs.get_outbound_redis_connection()
else:
self._redis_connection = None
def _get_redis_key(self, cache_name: str, key: str) -> str: def _get_redis_key(self, cache_name: str, key: str) -> str:
return "cache_v1:%s:%s" % (cache_name, key) return "cache_v1:%s:%s" % (cache_name, key)

View File

@ -294,7 +294,7 @@ class ReplicationCommandHandler:
# This shouldn't be possible # This shouldn't be possible
raise Exception("Unrecognised command %s in stream queue", cmd.NAME) raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
def start_replication(self, hs): def start_replication(self, hs: "HomeServer"):
"""Helper method to start a replication connection to the remote server """Helper method to start a replication connection to the remote server
using TCP. using TCP.
""" """
@ -321,6 +321,8 @@ class ReplicationCommandHandler:
hs.config.redis.redis_host, # type: ignore[arg-type] hs.config.redis.redis_host, # type: ignore[arg-type]
hs.config.redis.redis_port, hs.config.redis.redis_port,
self._factory, self._factory,
timeout=30,
bindAddress=None,
) )
else: else:
client_name = hs.get_instance_name() client_name = hs.get_instance_name()
@ -331,6 +333,8 @@ class ReplicationCommandHandler:
host, # type: ignore[arg-type] host, # type: ignore[arg-type]
port, port,
self._factory, self._factory,
timeout=30,
bindAddress=None,
) )
def get_streams(self) -> Dict[str, Stream]: def get_streams(self) -> Dict[str, Stream]:

View File

@ -16,6 +16,7 @@
import logging import logging
import random import random
from typing import TYPE_CHECKING
from prometheus_client import Counter from prometheus_client import Counter
@ -27,6 +28,9 @@ from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import EventsStream from synapse.replication.tcp.streams import EventsStream
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
stream_updates_counter = Counter( stream_updates_counter = Counter(
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
) )
@ -37,7 +41,7 @@ logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(Factory): class ReplicationStreamProtocolFactory(Factory):
"""Factory for new replication connections.""" """Factory for new replication connections."""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.command_handler = hs.get_tcp_replication() self.command_handler = hs.get_tcp_replication()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.config.server.server_name self.server_name = hs.config.server.server_name
@ -65,7 +69,7 @@ class ReplicationStreamer:
data is available it will propagate to all connected clients. data is available it will propagate to all connected clients.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()

View File

@ -241,7 +241,7 @@ class BackfillStream(Stream):
NAME = "backfill" NAME = "backfill"
ROW_TYPE = BackfillStreamRow ROW_TYPE = BackfillStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
@ -363,7 +363,7 @@ class ReceiptsStream(Stream):
NAME = "receipts" NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow ROW_TYPE = ReceiptsStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
@ -380,7 +380,7 @@ class PushRulesStream(Stream):
NAME = "push_rules" NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow ROW_TYPE = PushRulesStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
super().__init__( super().__init__(
@ -405,7 +405,7 @@ class PushersStream(Stream):
NAME = "pushers" NAME = "pushers"
ROW_TYPE = PushersStreamRow ROW_TYPE = PushersStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
@ -438,7 +438,7 @@ class CachesStream(Stream):
NAME = "caches" NAME = "caches"
ROW_TYPE = CachesStreamRow ROW_TYPE = CachesStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
@ -459,7 +459,7 @@ class DeviceListsStream(Stream):
NAME = "device_lists" NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
@ -476,7 +476,7 @@ class ToDeviceStream(Stream):
NAME = "to_device" NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow ROW_TYPE = ToDeviceStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
@ -495,7 +495,7 @@ class TagAccountDataStream(Stream):
NAME = "tag_account_data" NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow ROW_TYPE = TagAccountDataStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
@ -582,7 +582,7 @@ class GroupServerStream(Stream):
NAME = "groups" NAME = "groups"
ROW_TYPE = GroupsStreamRow ROW_TYPE = GroupsStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
@ -599,7 +599,7 @@ class UserSignatureStream(Stream):
NAME = "user_signature" NAME = "user_signature"
ROW_TYPE = UserSignatureStreamRow ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),

View File

@ -110,7 +110,7 @@ class DevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
""" """
Args: Args:
hs (synapse.server.HomeServer): server hs: server
""" """
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()

View File

@ -800,9 +800,14 @@ class HomeServer(metaclass=abc.ABCMeta):
return ExternalCache(self) return ExternalCache(self)
@cache_in_self @cache_in_self
def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: def get_outbound_redis_connection(self) -> "RedisProtocol":
if not self.config.redis.redis_enabled: """
return None The Redis connection used for replication.
Raises:
AssertionError: if Redis is not enabled in the homeserver config.
"""
assert self.config.redis.redis_enabled
# We only want to import redis module if we're using it, as we have # We only want to import redis module if we're using it, as we have
# `txredisapi` as an optional dependency. # `txredisapi` as an optional dependency.

View File

@ -19,6 +19,7 @@ from collections import defaultdict
from sys import intern from sys import intern
from time import monotonic as monotonic_time from time import monotonic as monotonic_time
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable, Callable,
Collection, Collection,
@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor from synapse.storage.types import Connection, Cursor
if TYPE_CHECKING:
from synapse.server import HomeServer
# python 3 does not have a maximum int value # python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1 MAX_TXN_ID = 2 ** 63 - 1
@ -392,7 +396,7 @@ class DatabasePool:
def __init__( def __init__(
self, self,
hs, hs: "HomeServer",
database_config: DatabaseConnectionConfig, database_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine, engine: BaseDatabaseEngine,
): ):

View File

@ -13,33 +13,49 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_conn from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Databases: DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True)
class Databases(Generic[DataStoreT]):
"""The various databases. """The various databases.
These are low level interfaces to physical databases. These are low level interfaces to physical databases.
Attributes: Attributes:
main (DataStore) databases
main
state
persist_events
""" """
def __init__(self, main_store_class, hs): databases: List[DatabasePool]
main: DataStoreT
state: StateGroupDataStore
persist_events: Optional[PersistEventsStore]
def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
# Note we pass in the main store class here as workers use a different main # Note we pass in the main store class here as workers use a different main
# store. # store.
self.databases = [] self.databases = []
main = None main: Optional[DataStoreT] = None
state = None state: Optional[StateGroupDataStore] = None
persist_events = None persist_events: Optional[PersistEventsStore] = None
for database_config in hs.config.database.databases: for database_config in hs.config.database.databases:
db_name = database_config.name db_name = database_config.name

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore from .user_erasure_store import UserErasureStore
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -126,7 +129,7 @@ class DataStore(
LockStore, LockStore,
SessionStore, SessionStore,
): ):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = database.engine self.database_engine = database.engine

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@ -28,6 +28,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore):
`get_max_account_data_stream_id` which can be called in the initializer. `get_max_account_data_stream_id` which can be called in the initializer.
""" """
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine): if isinstance(database.engine, PostgresEngine):

View File

@ -15,7 +15,7 @@
import itertools import itertools
import logging import logging
from typing import Any, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream from synapse.replication.tcp.streams import BackfillStream, CachesStream
@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore): class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.logging import issue9533_logger from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
@ -26,11 +26,14 @@ from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(

View File

@ -15,7 +15,17 @@
# limitations under the License. # limitations under the License.
import abc import abc
import logging import logging
from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)
from synapse.api.errors import Codes, StoreError from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore): class DeviceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks: if hs.config.worker.run_background_tasks:
@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies # Map of (user_id, device_id) -> bool. If there is an entry that implies

View File

@ -14,7 +14,7 @@
import itertools import itertools
import logging import logging
from queue import Empty, PriorityQueue from queue import Empty, PriorityQueue
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Gauge from prometheus_client import Counter, Gauge
@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
oldest_pdu_in_federation_staging = Gauge( oldest_pdu_in_federation_staging = Gauge(
"synapse_federation_server_oldest_inbound_pdu_in_staging", "synapse_federation_server_oldest_inbound_pdu_in_staging",
"The age in seconds since we received the oldest pdu in the federation staging area", "The age in seconds since we received the oldest pdu in the federation staging area",
@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception):
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks: if hs.config.worker.run_background_tasks:
@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only" EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr import attr
@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn # These get correctly set by _find_stream_orderings_for_times_txn
@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore): class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index" EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import attr import attr
@ -26,6 +26,9 @@ from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.types import JsonDict from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -76,7 +79,7 @@ class _CalculateChainCover:
class EventsBackgroundUpdatesStore(SQLBaseStore): class EventsBackgroundUpdatesStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(

View File

@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from enum import Enum from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
if TYPE_CHECKING:
from synapse.server import HomeServer
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
"media_repository_drop_index_wo_method" "media_repository_drop_index_wo_method"
) )
@ -43,7 +46,7 @@ class MediaSortOrder(Enum):
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars""" """Persistence for attachments and avatars"""
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname

View File

@ -14,7 +14,7 @@
import calendar import calendar
import logging import logging
import time import time
from typing import Dict from typing import TYPE_CHECKING, Dict
from synapse.metrics import GaugeBucketCollector from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore, EventPushActionsWorkerStore,
) )
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Collect metrics on the number of forward extremities that exist. # Collect metrics on the number of forward extremities that exist.
@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics. stats and prometheus metrics.
""" """
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes # Read the extrems every 60 minutes

View File

@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp # Number of msec of granularity to store the monthly_active_user timestamp
@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.hs = hs self.hs = hs
@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only self._mau_stats_only = hs.config.server.mau_stats_only

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import abc import abc
import logging import logging
from typing import Dict, List, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import NotFoundError, StoreError from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
@ -33,6 +33,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -75,7 +78,7 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer. `get_max_push_rules_stream_id` which can be called in the initializer.
""" """
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
@ -29,11 +29,14 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine): if isinstance(database.engine, PostgresEngine):

View File

@ -17,7 +17,7 @@ import collections
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
@ -32,6 +32,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import MXC_REGEX from synapse.util.stringutils import MXC_REGEX
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -69,7 +72,7 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore): class RoomWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
class RoomBackgroundUpdateStore(SQLBaseStore): class RoomBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config

View File

@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.state import _StateCacheEntry from synapse.state import _StateCacheEntry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None: async def forget(self, user_id: str, room_id: str) -> None:

View File

@ -15,7 +15,7 @@
import logging import logging
import re import re
from collections import namedtuple from collections import namedtuple
from typing import Collection, Iterable, List, Optional, Set from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase from synapse.events import EventBase
@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SearchEntry = namedtuple( SearchEntry = namedtuple(
@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if not hs.config.server.enable_search: if not hs.config.server.enable_search:
@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore): class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys): async def search_msgs(self, room_ids, search_term, keys):

View File

@ -15,7 +15,7 @@
import collections.abc import collections.abc
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Iterable, Optional, Set from typing import TYPE_CHECKING, Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@ -30,6 +30,9 @@ from synapse.types import StateMap
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -53,7 +56,7 @@ class _GetStateGroupDelta(
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.""" """The parts of StateGroupStore that can be called from workers."""
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion: async def get_room_version(self, room_id: str) -> RoomVersion:
@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname
@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events. * `state_groups_state`: Maps state group to state events.
""" """
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)

View File

@ -16,7 +16,7 @@
import logging import logging
from enum import Enum from enum import Enum
from itertools import chain from itertools import chain
from typing import Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing_extensions import Counter from typing_extensions import Counter
@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# these fields track absolutes (e.g. total number of rooms on the server) # these fields track absolutes (e.g. total number of rooms on the server)
@ -93,7 +96,7 @@ class UserSortOrder(Enum):
class StatsStore(StateDeltasStore): class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname

View File

@ -14,7 +14,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
import attr import attr
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
db_binary_type = memoryview db_binary_type = memoryview
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,7 +60,7 @@ class DestinationRetryTimings:
class TransactionWorkerStore(CacheInvalidationWorkerStore): class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks: if hs.config.worker.run_background_tasks:

View File

@ -18,6 +18,7 @@ import itertools
import logging import logging
from collections import deque from collections import deque
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
@ -56,6 +57,9 @@ from synapse.types import (
from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The number of times we are recalculating the current state # The number of times we are recalculating the current state
@ -272,7 +276,7 @@ class EventsPersistenceStorage:
current state and forward extremity changes. current state and forward extremity changes.
""" """
def __init__(self, hs, stores: Databases): def __init__(self, hs: "HomeServer", stores: Databases):
# We ultimately want to split out the state store from the main store, # We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same # so we use separate variables here even though they point to the same
# store for now. # store for now.