mirror of
https://github.com/matrix-org/synapse.git
synced 2025-01-27 02:19:20 +00:00
Add missing type hints to synapse.replication. (#11938)
This commit is contained in:
parent
8c94b3abe9
commit
d0e78af35e
1
changelog.d/11938.misc
Normal file
1
changelog.d/11938.misc
Normal file
@ -0,0 +1 @@
|
||||
Add missing type hints to replication code.
|
3
mypy.ini
3
mypy.ini
@ -169,6 +169,9 @@ disallow_untyped_defs = True
|
||||
[mypy-synapse.push.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.replication.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.rest.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
@ -40,7 +40,7 @@ class SlavedIdTracker(AbstractStreamIdTracker):
|
||||
for table, column in extra_tables:
|
||||
self.advance(None, _load_current_id(db_conn, table, column))
|
||||
|
||||
def advance(self, instance_name: Optional[str], new_id: int):
|
||||
def advance(self, instance_name: Optional[str], new_id: int) -> None:
|
||||
self._current = (max if self.step > 0 else min)(self._current, new_id)
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
|
@ -37,7 +37,9 @@ class SlavedClientIpStore(BaseSlavedStore):
|
||||
cache_name="client_ip_last_seen", max_size=50000
|
||||
)
|
||||
|
||||
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
||||
async def insert_client_ip(
|
||||
self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
|
||||
) -> None:
|
||||
now = int(self._clock.time_msec())
|
||||
key = (user_id, access_token, ip)
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Iterable
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
@ -60,7 +60,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||
def get_device_stream_token(self) -> int:
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||
) -> None:
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
self._invalidate_caches_for_devices(token, rows)
|
||||
@ -70,7 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def _invalidate_caches_for_devices(self, token, rows):
|
||||
def _invalidate_caches_for_devices(
|
||||
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
|
||||
) -> None:
|
||||
for row in rows:
|
||||
# The entities are either user IDs (starting with '@') whose devices
|
||||
# have changed, or remote servers that we need to tell about
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Iterable
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
@ -44,10 +44,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
||||
self._group_updates_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
def get_group_stream_token(self):
|
||||
def get_group_stream_token(self) -> int:
|
||||
return self._group_updates_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||
) -> None:
|
||||
if stream_name == GroupServerStream.NAME:
|
||||
self._group_updates_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
|
@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Iterable
|
||||
|
||||
from synapse.replication.tcp.streams import PushRulesStream
|
||||
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
|
||||
@ -20,10 +21,12 @@ from .events import SlavedEventStore
|
||||
|
||||
|
||||
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
||||
def get_max_push_rules_stream_id(self):
|
||||
def get_max_push_rules_stream_id(self) -> int:
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||
) -> None:
|
||||
if stream_name == PushRulesStream.NAME:
|
||||
self._push_rules_stream_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
|
@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Iterable
|
||||
|
||||
from synapse.replication.tcp.streams import PushersStream
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
@ -41,8 +41,8 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token, rows
|
||||
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||
) -> None:
|
||||
if stream_name == PushersStream.NAME:
|
||||
self._pushers_id_gen.advance(instance_name, token) # type: ignore
|
||||
self._pushers_id_gen.advance(instance_name, token)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
@ -14,10 +14,12 @@
|
||||
"""A replication client for use by synapse workers.
|
||||
"""
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IAddress, IConnector
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.federation import send_queue
|
||||
@ -79,10 +81,10 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
|
||||
|
||||
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
|
||||
|
||||
def startedConnecting(self, connector):
|
||||
def startedConnecting(self, connector: IConnector) -> None:
|
||||
logger.info("Connecting to replication: %r", connector.getDestination())
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol:
|
||||
logger.info("Connected to replication: %r", addr)
|
||||
return ClientReplicationStreamProtocol(
|
||||
self.hs,
|
||||
@ -92,11 +94,11 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
|
||||
self.command_handler,
|
||||
)
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
|
||||
logger.error("Lost replication conn: %r", reason)
|
||||
ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
|
||||
logger.error("Failed to connect to replication: %r", reason)
|
||||
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
|
||||
|
||||
@ -131,7 +133,7 @@ class ReplicationDataHandler:
|
||||
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
) -> None:
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
@ -252,14 +254,16 @@ class ReplicationDataHandler:
|
||||
# loop. (This maintains the order so no need to resort)
|
||||
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
|
||||
|
||||
async def on_position(self, stream_name: str, instance_name: str, token: int):
|
||||
async def on_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
await self.on_rdata(stream_name, instance_name, token, [])
|
||||
|
||||
# We poke the generic "replication" notifier to wake anything up that
|
||||
# may be streaming.
|
||||
self.notifier.notify_replication()
|
||||
|
||||
def on_remote_server_up(self, server: str):
|
||||
def on_remote_server_up(self, server: str) -> None:
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
|
||||
# Let's wake up the transaction queue for the server in case we have
|
||||
@ -269,7 +273,7 @@ class ReplicationDataHandler:
|
||||
|
||||
async def wait_for_stream_position(
|
||||
self, instance_name: str, stream_name: str, position: int
|
||||
):
|
||||
) -> None:
|
||||
"""Wait until this instance has received updates up to and including
|
||||
the given stream position.
|
||||
"""
|
||||
@ -304,7 +308,7 @@ class ReplicationDataHandler:
|
||||
"Finished waiting for repl stream %r to reach %s", stream_name, position
|
||||
)
|
||||
|
||||
def stop_pusher(self, user_id, app_id, pushkey):
|
||||
def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
|
||||
if not self._notify_pushers:
|
||||
return
|
||||
|
||||
@ -316,13 +320,13 @@ class ReplicationDataHandler:
|
||||
logger.info("Stopping pusher %r / %r", user_id, key)
|
||||
pusher.on_stop()
|
||||
|
||||
async def start_pusher(self, user_id, app_id, pushkey):
|
||||
async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
|
||||
if not self._notify_pushers:
|
||||
return
|
||||
|
||||
key = "%s:%s" % (app_id, pushkey)
|
||||
logger.info("Starting pusher %r / %r", user_id, key)
|
||||
return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
|
||||
await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
|
||||
|
||||
|
||||
class FederationSenderHandler:
|
||||
@ -353,10 +357,12 @@ class FederationSenderHandler:
|
||||
|
||||
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
|
||||
|
||||
def wake_destination(self, server: str):
|
||||
def wake_destination(self, server: str) -> None:
|
||||
self.federation_sender.wake_destination(server)
|
||||
|
||||
async def process_replication_rows(self, stream_name, token, rows):
|
||||
async def process_replication_rows(
|
||||
self, stream_name: str, token: int, rows: list
|
||||
) -> None:
|
||||
# The federation stream contains things that we want to send out, e.g.
|
||||
# presence, typing, etc.
|
||||
if stream_name == "federation":
|
||||
@ -384,11 +390,12 @@ class FederationSenderHandler:
|
||||
for host in hosts:
|
||||
self.federation_sender.send_device_messages(host)
|
||||
|
||||
async def _on_new_receipts(self, rows):
|
||||
async def _on_new_receipts(
|
||||
self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
|
||||
new receipts to be processed
|
||||
rows: new receipts to be processed
|
||||
"""
|
||||
for receipt in rows:
|
||||
# we only want to send on receipts for our own users
|
||||
@ -408,7 +415,7 @@ class FederationSenderHandler:
|
||||
)
|
||||
await self.federation_sender.send_read_receipt(receipt_info)
|
||||
|
||||
async def update_token(self, token):
|
||||
async def update_token(self, token: int) -> None:
|
||||
"""Update the record of where we have processed to in the federation stream.
|
||||
|
||||
Called after we have processed a an update received over replication. Sends
|
||||
@ -428,7 +435,7 @@ class FederationSenderHandler:
|
||||
|
||||
run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
|
||||
|
||||
async def _save_and_send_ack(self):
|
||||
async def _save_and_send_ack(self) -> None:
|
||||
"""Save the current federation position in the database and send an ACK
|
||||
to master with where we're up to.
|
||||
"""
|
||||
|
@ -18,12 +18,15 @@ allowed to be sent by which side.
|
||||
"""
|
||||
import abc
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
from typing import Optional, Tuple, Type, TypeVar
|
||||
|
||||
from synapse.replication.tcp.streams._base import StreamRow
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound="Command")
|
||||
|
||||
|
||||
class Command(metaclass=abc.ABCMeta):
|
||||
"""The base command class.
|
||||
@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def from_line(cls, line):
|
||||
def from_line(cls: Type[T], line: str) -> T:
|
||||
"""Deserialises a line from the wire into this command. `line` does not
|
||||
include the command.
|
||||
"""
|
||||
@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta):
|
||||
prefix.
|
||||
"""
|
||||
|
||||
def get_logcontext_id(self):
|
||||
def get_logcontext_id(self) -> str:
|
||||
"""Get a suitable string for the logcontext when processing this command"""
|
||||
|
||||
# by default, we just use the command name.
|
||||
return self.NAME
|
||||
|
||||
|
||||
SC = TypeVar("SC", bound="_SimpleCommand")
|
||||
|
||||
|
||||
class _SimpleCommand(Command):
|
||||
"""An implementation of Command whose argument is just a 'data' string."""
|
||||
|
||||
def __init__(self, data):
|
||||
def __init__(self, data: str):
|
||||
self.data = data
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
def from_line(cls: Type[SC], line: str) -> SC:
|
||||
return cls(line)
|
||||
|
||||
def to_line(self) -> str:
|
||||
@ -109,14 +115,16 @@ class RdataCommand(Command):
|
||||
|
||||
NAME = "RDATA"
|
||||
|
||||
def __init__(self, stream_name, instance_name, token, row):
|
||||
def __init__(
|
||||
self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow
|
||||
):
|
||||
self.stream_name = stream_name
|
||||
self.instance_name = instance_name
|
||||
self.token = token
|
||||
self.row = row
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand":
|
||||
stream_name, instance_name, token, row_json = line.split(" ", 3)
|
||||
return cls(
|
||||
stream_name,
|
||||
@ -125,7 +133,7 @@ class RdataCommand(Command):
|
||||
json_decoder.decode(row_json),
|
||||
)
|
||||
|
||||
def to_line(self):
|
||||
def to_line(self) -> str:
|
||||
return " ".join(
|
||||
(
|
||||
self.stream_name,
|
||||
@ -135,7 +143,7 @@ class RdataCommand(Command):
|
||||
)
|
||||
)
|
||||
|
||||
def get_logcontext_id(self):
|
||||
def get_logcontext_id(self) -> str:
|
||||
return "RDATA-" + self.stream_name
|
||||
|
||||
|
||||
@ -164,18 +172,20 @@ class PositionCommand(Command):
|
||||
|
||||
NAME = "POSITION"
|
||||
|
||||
def __init__(self, stream_name, instance_name, prev_token, new_token):
|
||||
def __init__(
|
||||
self, stream_name: str, instance_name: str, prev_token: int, new_token: int
|
||||
):
|
||||
self.stream_name = stream_name
|
||||
self.instance_name = instance_name
|
||||
self.prev_token = prev_token
|
||||
self.new_token = new_token
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand":
|
||||
stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
|
||||
return cls(stream_name, instance_name, int(prev_token), int(new_token))
|
||||
|
||||
def to_line(self):
|
||||
def to_line(self) -> str:
|
||||
return " ".join(
|
||||
(
|
||||
self.stream_name,
|
||||
@ -218,14 +228,14 @@ class ReplicateCommand(Command):
|
||||
|
||||
NAME = "REPLICATE"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
def from_line(cls: Type[T], line: str) -> T:
|
||||
return cls()
|
||||
|
||||
def to_line(self):
|
||||
def to_line(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
@ -247,14 +257,16 @@ class UserSyncCommand(Command):
|
||||
|
||||
NAME = "USER_SYNC"
|
||||
|
||||
def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
|
||||
def __init__(
|
||||
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
||||
):
|
||||
self.instance_id = instance_id
|
||||
self.user_id = user_id
|
||||
self.is_syncing = is_syncing
|
||||
self.last_sync_ms = last_sync_ms
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
|
||||
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
|
||||
|
||||
if state not in ("start", "end"):
|
||||
@ -262,7 +274,7 @@ class UserSyncCommand(Command):
|
||||
|
||||
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
|
||||
|
||||
def to_line(self):
|
||||
def to_line(self) -> str:
|
||||
return " ".join(
|
||||
(
|
||||
self.instance_id,
|
||||
@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command):
|
||||
|
||||
NAME = "CLEAR_USER_SYNC"
|
||||
|
||||
def __init__(self, instance_id):
|
||||
def __init__(self, instance_id: str):
|
||||
self.instance_id = instance_id
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
def from_line(
|
||||
cls: Type["ClearUserSyncsCommand"], line: str
|
||||
) -> "ClearUserSyncsCommand":
|
||||
return cls(line)
|
||||
|
||||
def to_line(self):
|
||||
def to_line(self) -> str:
|
||||
return self.instance_id
|
||||
|
||||
|
||||
@ -316,7 +330,9 @@ class FederationAckCommand(Command):
|
||||
self.token = token
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line: str) -> "FederationAckCommand":
|
||||
def from_line(
|
||||
cls: Type["FederationAckCommand"], line: str
|
||||
) -> "FederationAckCommand":
|
||||
instance_name, token = line.split(" ")
|
||||
return cls(instance_name, int(token))
|
||||
|
||||
@ -334,7 +350,15 @@ class UserIpCommand(Command):
|
||||
|
||||
NAME = "USER_IP"
|
||||
|
||||
def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
access_token: str,
|
||||
ip: str,
|
||||
user_agent: str,
|
||||
device_id: str,
|
||||
last_seen: int,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.access_token = access_token
|
||||
self.ip = ip
|
||||
@ -343,14 +367,14 @@ class UserIpCommand(Command):
|
||||
self.last_seen = last_seen
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand":
|
||||
user_id, jsn = line.split(" ", 1)
|
||||
|
||||
access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
|
||||
|
||||
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||
|
||||
def to_line(self):
|
||||
def to_line(self) -> str:
|
||||
return (
|
||||
self.user_id
|
||||
+ " "
|
||||
|
@ -261,7 +261,7 @@ class ReplicationCommandHandler:
|
||||
"process-replication-data", self._unsafe_process_queue, stream_name
|
||||
)
|
||||
|
||||
async def _unsafe_process_queue(self, stream_name: str):
|
||||
async def _unsafe_process_queue(self, stream_name: str) -> None:
|
||||
"""Processes the command queue for the given stream, until it is empty
|
||||
|
||||
Does not check if there is already a thread processing the queue, hence "unsafe"
|
||||
@ -294,7 +294,7 @@ class ReplicationCommandHandler:
|
||||
# This shouldn't be possible
|
||||
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
|
||||
|
||||
def start_replication(self, hs: "HomeServer"):
|
||||
def start_replication(self, hs: "HomeServer") -> None:
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
using TCP.
|
||||
"""
|
||||
@ -345,10 +345,10 @@ class ReplicationCommandHandler:
|
||||
"""Get a list of streams that this instances replicates."""
|
||||
return self._streams_to_replicate
|
||||
|
||||
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
|
||||
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None:
|
||||
self.send_positions_to_connection(conn)
|
||||
|
||||
def send_positions_to_connection(self, conn: IReplicationConnection):
|
||||
def send_positions_to_connection(self, conn: IReplicationConnection) -> None:
|
||||
"""Send current position of all streams this process is source of to
|
||||
the connection.
|
||||
"""
|
||||
@ -392,7 +392,7 @@ class ReplicationCommandHandler:
|
||||
|
||||
def on_FEDERATION_ACK(
|
||||
self, conn: IReplicationConnection, cmd: FederationAckCommand
|
||||
):
|
||||
) -> None:
|
||||
federation_ack_counter.inc()
|
||||
|
||||
if self._federation_sender:
|
||||
@ -408,7 +408,7 @@ class ReplicationCommandHandler:
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _handle_user_ip(self, cmd: UserIpCommand):
|
||||
async def _handle_user_ip(self, cmd: UserIpCommand) -> None:
|
||||
await self._store.insert_client_ip(
|
||||
cmd.user_id,
|
||||
cmd.access_token,
|
||||
@ -421,7 +421,7 @@ class ReplicationCommandHandler:
|
||||
assert self._server_notices_sender is not None
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
|
||||
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None:
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore RDATA that are just our own echoes
|
||||
return
|
||||
@ -497,7 +497,7 @@ class ReplicationCommandHandler:
|
||||
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
) -> None:
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
@ -512,7 +512,7 @@ class ReplicationCommandHandler:
|
||||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
|
||||
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None:
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore POSITION that are just our own echoes
|
||||
return
|
||||
@ -581,7 +581,7 @@ class ReplicationCommandHandler:
|
||||
|
||||
def on_REMOTE_SERVER_UP(
|
||||
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
|
||||
):
|
||||
) -> None:
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||
|
||||
@ -604,7 +604,7 @@ class ReplicationCommandHandler:
|
||||
# between two instances, but that is not currently supported).
|
||||
self.send_command(cmd, ignore_conn=conn)
|
||||
|
||||
def new_connection(self, connection: IReplicationConnection):
|
||||
def new_connection(self, connection: IReplicationConnection) -> None:
|
||||
"""Called when we have a new connection."""
|
||||
self._connections.append(connection)
|
||||
|
||||
@ -631,7 +631,7 @@ class ReplicationCommandHandler:
|
||||
UserSyncCommand(self._instance_id, user_id, True, now)
|
||||
)
|
||||
|
||||
def lost_connection(self, connection: IReplicationConnection):
|
||||
def lost_connection(self, connection: IReplicationConnection) -> None:
|
||||
"""Called when a connection is closed/lost."""
|
||||
# we no longer need _streams_by_connection for this connection.
|
||||
streams = self._streams_by_connection.pop(connection, None)
|
||||
@ -653,7 +653,7 @@ class ReplicationCommandHandler:
|
||||
|
||||
def send_command(
|
||||
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
|
||||
):
|
||||
) -> None:
|
||||
"""Send a command to all connected connections.
|
||||
|
||||
Args:
|
||||
@ -680,7 +680,7 @@ class ReplicationCommandHandler:
|
||||
else:
|
||||
logger.warning("Dropping command as not connected: %r", cmd.NAME)
|
||||
|
||||
def send_federation_ack(self, token: int):
|
||||
def send_federation_ack(self, token: int) -> None:
|
||||
"""Ack data for the federation stream. This allows the master to drop
|
||||
data stored purely in memory.
|
||||
"""
|
||||
@ -688,7 +688,7 @@ class ReplicationCommandHandler:
|
||||
|
||||
def send_user_sync(
|
||||
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
||||
):
|
||||
) -> None:
|
||||
"""Poke the master that a user has started/stopped syncing."""
|
||||
self.send_command(
|
||||
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
||||
@ -702,15 +702,15 @@ class ReplicationCommandHandler:
|
||||
user_agent: str,
|
||||
device_id: str,
|
||||
last_seen: int,
|
||||
):
|
||||
) -> None:
|
||||
"""Tell the master that the user made a request."""
|
||||
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
def send_remote_server_up(self, server: str) -> None:
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def stream_update(self, stream_name: str, token: str, data: Any):
|
||||
def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
|
||||
"""Called when a new update is available to stream to clients.
|
||||
|
||||
We need to check if the client is interested in the stream or not
|
||||
|
@ -49,7 +49,7 @@ import fcntl
|
||||
import logging
|
||||
import struct
|
||||
from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING, Collection, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Collection, List, Optional
|
||||
|
||||
from prometheus_client import Counter
|
||||
from zope.interface import Interface, implementer
|
||||
@ -123,7 +123,7 @@ class ConnectionStates:
|
||||
class IReplicationConnection(Interface):
|
||||
"""An interface for replication connections."""
|
||||
|
||||
def send_command(cmd: Command):
|
||||
def send_command(cmd: Command) -> None:
|
||||
"""Send the command down the connection"""
|
||||
|
||||
|
||||
@ -190,7 +190,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
"replication-conn", self.conn_id
|
||||
)
|
||||
|
||||
def connectionMade(self):
|
||||
def connectionMade(self) -> None:
|
||||
logger.info("[%s] Connection established", self.id())
|
||||
|
||||
self.state = ConnectionStates.ESTABLISHED
|
||||
@ -207,11 +207,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
|
||||
# Always send the initial PING so that the other side knows that they
|
||||
# can time us out.
|
||||
self.send_command(PingCommand(self.clock.time_msec()))
|
||||
self.send_command(PingCommand(str(self.clock.time_msec())))
|
||||
|
||||
self.command_handler.new_connection(self)
|
||||
|
||||
def send_ping(self):
|
||||
def send_ping(self) -> None:
|
||||
"""Periodically sends a ping and checks if we should close the connection
|
||||
due to the other side timing out.
|
||||
"""
|
||||
@ -226,7 +226,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
self.transport.abortConnection()
|
||||
else:
|
||||
if now - self.last_sent_command >= PING_TIME:
|
||||
self.send_command(PingCommand(now))
|
||||
self.send_command(PingCommand(str(now)))
|
||||
|
||||
if (
|
||||
self.received_ping
|
||||
@ -239,12 +239,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
)
|
||||
self.send_error("ping timeout")
|
||||
|
||||
def lineReceived(self, line: bytes):
|
||||
def lineReceived(self, line: bytes) -> None:
|
||||
"""Called when we've received a line"""
|
||||
with PreserveLoggingContext(self._logging_context):
|
||||
self._parse_and_dispatch_line(line)
|
||||
|
||||
def _parse_and_dispatch_line(self, line: bytes):
|
||||
def _parse_and_dispatch_line(self, line: bytes) -> None:
|
||||
if line.strip() == "":
|
||||
# Ignore blank lines
|
||||
return
|
||||
@ -309,24 +309,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
if not handled:
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
logger.warning("[%s] Closing connection", self.id())
|
||||
self.time_we_closed = self.clock.time_msec()
|
||||
assert self.transport is not None
|
||||
self.transport.loseConnection()
|
||||
self.on_connection_closed()
|
||||
|
||||
def send_error(self, error_string, *args):
|
||||
def send_error(self, error_string: str, *args: Any) -> None:
|
||||
"""Send an error to remote and close the connection."""
|
||||
self.send_command(ErrorCommand(error_string % args))
|
||||
self.close()
|
||||
|
||||
def send_command(self, cmd, do_buffer=True):
|
||||
def send_command(self, cmd: Command, do_buffer: bool = True) -> None:
|
||||
"""Send a command if connection has been established.
|
||||
|
||||
Args:
|
||||
cmd (Command)
|
||||
do_buffer (bool): Whether to buffer the message or always attempt
|
||||
cmd
|
||||
do_buffer: Whether to buffer the message or always attempt
|
||||
to send the command. This is mostly used to send an error
|
||||
message if we're about to close the connection due our buffers
|
||||
becoming full.
|
||||
@ -357,7 +357,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
|
||||
self.last_sent_command = self.clock.time_msec()
|
||||
|
||||
def _queue_command(self, cmd):
|
||||
def _queue_command(self, cmd: Command) -> None:
|
||||
"""Queue the command until the connection is ready to write to again."""
|
||||
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
|
||||
self.pending_commands.append(cmd)
|
||||
@ -370,20 +370,20 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False)
|
||||
self.close()
|
||||
|
||||
def _send_pending_commands(self):
|
||||
def _send_pending_commands(self) -> None:
|
||||
"""Send any queued commandes"""
|
||||
pending = self.pending_commands
|
||||
self.pending_commands = []
|
||||
for cmd in pending:
|
||||
self.send_command(cmd)
|
||||
|
||||
def on_PING(self, line):
|
||||
def on_PING(self, cmd: PingCommand) -> None:
|
||||
self.received_ping = True
|
||||
|
||||
def on_ERROR(self, cmd):
|
||||
def on_ERROR(self, cmd: ErrorCommand) -> None:
|
||||
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
|
||||
|
||||
def pauseProducing(self):
|
||||
def pauseProducing(self) -> None:
|
||||
"""This is called when both the kernel send buffer and the twisted
|
||||
tcp connection send buffers have become full.
|
||||
|
||||
@ -394,26 +394,26 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
logger.info("[%s] Pause producing", self.id())
|
||||
self.state = ConnectionStates.PAUSED
|
||||
|
||||
def resumeProducing(self):
|
||||
def resumeProducing(self) -> None:
|
||||
"""The remote has caught up after we started buffering!"""
|
||||
logger.info("[%s] Resume producing", self.id())
|
||||
self.state = ConnectionStates.ESTABLISHED
|
||||
self._send_pending_commands()
|
||||
|
||||
def stopProducing(self):
|
||||
def stopProducing(self) -> None:
|
||||
"""We're never going to send any more data (normally because either
|
||||
we or the remote has closed the connection)
|
||||
"""
|
||||
logger.info("[%s] Stop producing", self.id())
|
||||
self.on_connection_closed()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
def connectionLost(self, reason: Failure) -> None: # type: ignore[override]
|
||||
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
|
||||
if isinstance(reason, Failure):
|
||||
assert reason.type is not None
|
||||
connection_close_counter.labels(reason.type.__name__).inc()
|
||||
else:
|
||||
connection_close_counter.labels(reason.__class__.__name__).inc()
|
||||
connection_close_counter.labels(reason.__class__.__name__).inc() # type: ignore[unreachable]
|
||||
|
||||
try:
|
||||
# Remove us from list of connections to be monitored
|
||||
@ -427,7 +427,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
|
||||
self.on_connection_closed()
|
||||
|
||||
def on_connection_closed(self):
|
||||
def on_connection_closed(self) -> None:
|
||||
logger.info("[%s] Connection was closed", self.id())
|
||||
|
||||
self.state = ConnectionStates.CLOSED
|
||||
@ -445,7 +445,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
# the sentinel context is now active, which may not be correct.
|
||||
# PreserveLoggingContext() will restore the correct logging context.
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
addr = None
|
||||
if self.transport:
|
||||
addr = str(self.transport.getPeer())
|
||||
@ -455,10 +455,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
addr,
|
||||
)
|
||||
|
||||
def id(self):
|
||||
def id(self) -> str:
|
||||
return "%s-%s" % (self.name, self.conn_id)
|
||||
|
||||
def lineLengthExceeded(self, line):
|
||||
def lineLengthExceeded(self, line: str) -> None:
|
||||
"""Called when we receive a line that is above the maximum line length"""
|
||||
self.send_error("Line length exceeded")
|
||||
|
||||
@ -474,11 +474,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
|
||||
self.server_name = server_name
|
||||
|
||||
def connectionMade(self):
|
||||
def connectionMade(self) -> None:
|
||||
self.send_command(ServerCommand(self.server_name))
|
||||
super().connectionMade()
|
||||
|
||||
def on_NAME(self, cmd):
|
||||
def on_NAME(self, cmd: NameCommand) -> None:
|
||||
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
||||
self.name = cmd.data
|
||||
|
||||
@ -500,19 +500,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
self.client_name = client_name
|
||||
self.server_name = server_name
|
||||
|
||||
def connectionMade(self):
|
||||
def connectionMade(self) -> None:
|
||||
self.send_command(NameCommand(self.client_name))
|
||||
super().connectionMade()
|
||||
|
||||
# Once we've connected subscribe to the necessary streams
|
||||
self.replicate()
|
||||
|
||||
def on_SERVER(self, cmd):
|
||||
def on_SERVER(self, cmd: ServerCommand) -> None:
|
||||
if cmd.data != self.server_name:
|
||||
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
||||
self.send_error("Wrong remote")
|
||||
|
||||
def replicate(self):
|
||||
def replicate(self) -> None:
|
||||
"""Send the subscription request to the server"""
|
||||
logger.info("[%s] Subscribing to replication streams", self.id())
|
||||
|
||||
@ -529,7 +529,7 @@ pending_commands = LaterGauge(
|
||||
)
|
||||
|
||||
|
||||
def transport_buffer_size(protocol):
|
||||
def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int:
|
||||
if protocol.transport:
|
||||
size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen
|
||||
return size
|
||||
@ -544,7 +544,9 @@ transport_send_buffer = LaterGauge(
|
||||
)
|
||||
|
||||
|
||||
def transport_kernel_read_buffer_size(protocol, read=True):
|
||||
def transport_kernel_read_buffer_size(
|
||||
protocol: BaseReplicationStreamProtocol, read: bool = True
|
||||
) -> int:
|
||||
SIOCINQ = 0x541B
|
||||
SIOCOUTQ = 0x5411
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import logging
|
||||
from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
|
||||
|
||||
import attr
|
||||
import txredisapi
|
||||
@ -62,7 +62,7 @@ class ConstantProperty(Generic[T, V]):
|
||||
def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
|
||||
return self.constant
|
||||
|
||||
def __set__(self, obj: Optional[T], value: V):
|
||||
def __set__(self, obj: Optional[T], value: V) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -95,7 +95,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||
synapse_stream_name: str
|
||||
synapse_outbound_redis_connection: txredisapi.RedisProtocol
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# a logcontext which we use for processing incoming commands. We declare it as a
|
||||
@ -108,12 +108,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||
"replication_command_handler"
|
||||
)
|
||||
|
||||
def connectionMade(self):
|
||||
def connectionMade(self) -> None:
|
||||
logger.info("Connected to redis")
|
||||
super().connectionMade()
|
||||
run_as_background_process("subscribe-replication", self._send_subscribe)
|
||||
|
||||
async def _send_subscribe(self):
|
||||
async def _send_subscribe(self) -> None:
|
||||
# it's important to make sure that we only send the REPLICATE command once we
|
||||
# have successfully subscribed to the stream - otherwise we might miss the
|
||||
# POSITION response sent back by the other end.
|
||||
@ -131,12 +131,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||
# otherside won't know we've connected and so won't issue a REPLICATE.
|
||||
self.synapse_handler.send_positions_to_connection(self)
|
||||
|
||||
def messageReceived(self, pattern: str, channel: str, message: str):
|
||||
def messageReceived(self, pattern: str, channel: str, message: str) -> None:
|
||||
"""Received a message from redis."""
|
||||
with PreserveLoggingContext(self._logging_context):
|
||||
self._parse_and_dispatch_message(message)
|
||||
|
||||
def _parse_and_dispatch_message(self, message: str):
|
||||
def _parse_and_dispatch_message(self, message: str) -> None:
|
||||
if message.strip() == "":
|
||||
# Ignore blank lines
|
||||
return
|
||||
@ -181,7 +181,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||
"replication-" + cmd.get_logcontext_id(), lambda: res
|
||||
)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
def connectionLost(self, reason: Failure) -> None: # type: ignore[override]
|
||||
logger.info("Lost connection to redis")
|
||||
super().connectionLost(reason)
|
||||
self.synapse_handler.lost_connection(self)
|
||||
@ -193,17 +193,17 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||
# the sentinel context is now active, which may not be correct.
|
||||
# PreserveLoggingContext() will restore the correct logging context.
|
||||
|
||||
def send_command(self, cmd: Command):
|
||||
def send_command(self, cmd: Command) -> None:
|
||||
"""Send a command if connection has been established.
|
||||
|
||||
Args:
|
||||
cmd (Command)
|
||||
cmd: The command to send
|
||||
"""
|
||||
run_as_background_process(
|
||||
"send-cmd", self._async_send_command, cmd, bg_start_span=False
|
||||
)
|
||||
|
||||
async def _async_send_command(self, cmd: Command):
|
||||
async def _async_send_command(self, cmd: Command) -> None:
|
||||
"""Encode a replication command and send it over our outbound connection"""
|
||||
string = "%s %s" % (cmd.NAME, cmd.to_line())
|
||||
if "\n" in string:
|
||||
@ -259,7 +259,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
|
||||
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
|
||||
|
||||
@wrap_as_background_process("redis_ping")
|
||||
async def _send_ping(self):
|
||||
async def _send_ping(self) -> None:
|
||||
for connection in self.pool:
|
||||
try:
|
||||
await make_deferred_yieldable(connection.ping())
|
||||
@ -269,13 +269,13 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
|
||||
# ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
|
||||
# it's rubbish. We add our own here.
|
||||
|
||||
def startedConnecting(self, connector: IConnector):
|
||||
def startedConnecting(self, connector: IConnector) -> None:
|
||||
logger.info(
|
||||
"Connecting to redis server %s", format_address(connector.getDestination())
|
||||
)
|
||||
super().startedConnecting(connector)
|
||||
|
||||
def clientConnectionFailed(self, connector: IConnector, reason: Failure):
|
||||
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
|
||||
logger.info(
|
||||
"Connection to redis server %s failed: %s",
|
||||
format_address(connector.getDestination()),
|
||||
@ -283,7 +283,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
|
||||
)
|
||||
super().clientConnectionFailed(connector, reason)
|
||||
|
||||
def clientConnectionLost(self, connector: IConnector, reason: Failure):
|
||||
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
|
||||
logger.info(
|
||||
"Connection to redis server %s lost: %s",
|
||||
format_address(connector.getDestination()),
|
||||
@ -330,7 +330,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
||||
|
||||
self.synapse_outbound_redis_connection = outbound_redis_connection
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
|
||||
p = super().buildProtocol(addr)
|
||||
p = cast(RedisSubscriber, p)
|
||||
|
||||
|
@ -16,16 +16,18 @@
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet.interfaces import IAddress
|
||||
from twisted.internet.protocol import ServerFactory
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.commands import PositionCommand
|
||||
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
|
||||
from synapse.replication.tcp.streams import EventsStream
|
||||
from synapse.replication.tcp.streams._base import StreamRow, Token
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -56,7 +58,7 @@ class ReplicationStreamProtocolFactory(ServerFactory):
|
||||
# listener config again or always starting a `ReplicationStreamer`.)
|
||||
hs.get_replication_streamer()
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol:
|
||||
return ServerReplicationStreamProtocol(
|
||||
self.server_name, self.clock, self.command_handler
|
||||
)
|
||||
@ -105,7 +107,7 @@ class ReplicationStreamer:
|
||||
if any(EventsStream.NAME == s.NAME for s in self.streams):
|
||||
self.clock.looping_call(self.on_notifier_poke, 1000)
|
||||
|
||||
def on_notifier_poke(self):
|
||||
def on_notifier_poke(self) -> None:
|
||||
"""Checks if there is actually any new data and sends it to the
|
||||
connections if there are.
|
||||
|
||||
@ -137,7 +139,7 @@ class ReplicationStreamer:
|
||||
|
||||
run_as_background_process("replication_notifier", self._run_notifier_loop)
|
||||
|
||||
async def _run_notifier_loop(self):
|
||||
async def _run_notifier_loop(self) -> None:
|
||||
self.is_looping = True
|
||||
|
||||
try:
|
||||
@ -238,7 +240,9 @@ class ReplicationStreamer:
|
||||
self.is_looping = False
|
||||
|
||||
|
||||
def _batch_updates(updates):
|
||||
def _batch_updates(
|
||||
updates: List[Tuple[Token, StreamRow]]
|
||||
) -> List[Tuple[Optional[Token], StreamRow]]:
|
||||
"""Takes a list of updates of form [(token, row)] and sets the token to
|
||||
None for all rows where the next row has the same token. This is used to
|
||||
implement batching.
|
||||
@ -254,7 +258,7 @@ def _batch_updates(updates):
|
||||
if not updates:
|
||||
return []
|
||||
|
||||
new_updates = []
|
||||
new_updates: List[Tuple[Optional[Token], StreamRow]] = []
|
||||
for i, update in enumerate(updates[:-1]):
|
||||
if update[0] == updates[i + 1][0]:
|
||||
new_updates.append((None, update[1]))
|
||||
|
@ -90,7 +90,7 @@ class Stream:
|
||||
ROW_TYPE: Any = None
|
||||
|
||||
@classmethod
|
||||
def parse_row(cls, row: StreamRow):
|
||||
def parse_row(cls, row: StreamRow) -> Any:
|
||||
"""Parse a row received over replication
|
||||
|
||||
By default, assumes that the row data is an array object and passes its contents
|
||||
@ -139,7 +139,7 @@ class Stream:
|
||||
# The token from which we last asked for updates
|
||||
self.last_token = self.current_token(self.local_instance_name)
|
||||
|
||||
def discard_updates_and_advance(self):
|
||||
def discard_updates_and_advance(self) -> None:
|
||||
"""Called when the stream should advance but the updates would be discarded,
|
||||
e.g. when there are no currently connected workers.
|
||||
"""
|
||||
@ -200,7 +200,7 @@ def current_token_without_instance(
|
||||
return lambda instance_name: current_token()
|
||||
|
||||
|
||||
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
|
||||
def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction:
|
||||
"""Makes a suitable function for use as an `update_function` that queries
|
||||
the master process for updates.
|
||||
"""
|
||||
|
@ -13,12 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import heapq
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Stream, StreamUpdateResult, Token
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
Stream,
|
||||
StreamRow,
|
||||
StreamUpdateResult,
|
||||
Token,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@ -58,6 +62,9 @@ class EventsStreamRow:
|
||||
data: "BaseEventsStreamRow"
|
||||
|
||||
|
||||
T = TypeVar("T", bound="BaseEventsStreamRow")
|
||||
|
||||
|
||||
class BaseEventsStreamRow:
|
||||
"""Base class for rows to be sent in the events stream.
|
||||
|
||||
@ -68,7 +75,7 @@ class BaseEventsStreamRow:
|
||||
TypeId: str
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, data):
|
||||
def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T:
|
||||
"""Parse the data from the replication stream into a row.
|
||||
|
||||
By default we just call the constructor with the data list as arguments
|
||||
@ -221,7 +228,7 @@ class EventsStream(Stream):
|
||||
return updates, upper_limit, limited
|
||||
|
||||
@classmethod
|
||||
def parse_row(cls, row):
|
||||
(typ, data) = row
|
||||
data = TypeToRow[typ].from_data(data)
|
||||
return EventsStreamRow(typ, data)
|
||||
def parse_row(cls, row: StreamRow) -> "EventsStreamRow":
|
||||
(typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row)
|
||||
event_stream_row_data = TypeToRow[typ].from_data(data)
|
||||
return EventsStreamRow(typ, event_stream_row_data)
|
||||
|
@ -16,8 +16,7 @@ import itertools
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Tuple
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
from netaddr import valid_ipv6
|
||||
|
||||
@ -197,7 +196,7 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
|
||||
"""If iterable has maxitems or fewer, return the stringification of a list
|
||||
containing those items.
|
||||
|
||||
Otherwise, return the stringification of a a list with the first maxitems items,
|
||||
Otherwise, return the stringification of a list with the first maxitems items,
|
||||
followed by "...".
|
||||
|
||||
Args:
|
||||
|
@ -14,6 +14,7 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
@ -53,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
self.streamer = hs.get_replication_streamer()
|
||||
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
|
||||
None
|
||||
IPv4Address("TCP", "127.0.0.1", 0)
|
||||
)
|
||||
|
||||
# Make a new HomeServer object for the worker
|
||||
@ -345,7 +346,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
self.clock,
|
||||
repl_handler,
|
||||
)
|
||||
server = self.server_factory.buildProtocol(None)
|
||||
server = self.server_factory.buildProtocol(
|
||||
IPv4Address("TCP", "127.0.0.1", 0)
|
||||
)
|
||||
|
||||
client_transport = FakeTransport(server, self.reactor)
|
||||
client.makeConnection(client_transport)
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.interfaces import IProtocol
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
@ -29,7 +30,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
|
||||
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
|
||||
"""Create a new direct TCP replication connection"""
|
||||
|
||||
proto = self.factory.buildProtocol(("127.0.0.1", 0))
|
||||
proto = self.factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 0))
|
||||
transport = StringTransport()
|
||||
proto.makeConnection(transport)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user