Add additional type hints to the storage module. (#8980)
This commit is contained in:
parent
b8591899ab
commit
637282bb50
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to the base storage code.
|
10
mypy.ini
10
mypy.ini
|
@ -70,6 +70,9 @@ files =
|
||||||
synapse/server_notices,
|
synapse/server_notices,
|
||||||
synapse/spam_checker_api,
|
synapse/spam_checker_api,
|
||||||
synapse/state,
|
synapse/state,
|
||||||
|
synapse/storage/__init__.py,
|
||||||
|
synapse/storage/_base.py,
|
||||||
|
synapse/storage/background_updates.py,
|
||||||
synapse/storage/databases/main/appservice.py,
|
synapse/storage/databases/main/appservice.py,
|
||||||
synapse/storage/databases/main/events.py,
|
synapse/storage/databases/main/events.py,
|
||||||
synapse/storage/databases/main/pusher.py,
|
synapse/storage/databases/main/pusher.py,
|
||||||
|
@ -78,8 +81,15 @@ files =
|
||||||
synapse/storage/databases/main/ui_auth.py,
|
synapse/storage/databases/main/ui_auth.py,
|
||||||
synapse/storage/database.py,
|
synapse/storage/database.py,
|
||||||
synapse/storage/engines,
|
synapse/storage/engines,
|
||||||
|
synapse/storage/keys.py,
|
||||||
synapse/storage/persist_events.py,
|
synapse/storage/persist_events.py,
|
||||||
|
synapse/storage/prepare_database.py,
|
||||||
|
synapse/storage/purge_events.py,
|
||||||
|
synapse/storage/push_rule.py,
|
||||||
|
synapse/storage/relations.py,
|
||||||
|
synapse/storage/roommember.py,
|
||||||
synapse/storage/state.py,
|
synapse/storage/state.py,
|
||||||
|
synapse/storage/types.py,
|
||||||
synapse/storage/util,
|
synapse/storage/util,
|
||||||
synapse/streams,
|
synapse/streams,
|
||||||
synapse/types.py,
|
synapse/types.py,
|
||||||
|
|
|
@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler):
|
||||||
member_event_id: str,
|
member_event_id: str,
|
||||||
is_peeking: bool,
|
is_peeking: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
room_state = await self.state_store.get_state_for_events([member_event_id])
|
room_state = await self.state_store.get_state_for_event(member_event_id)
|
||||||
|
|
||||||
room_state = room_state[member_event_id]
|
|
||||||
|
|
||||||
limit = pagin_config.limit if pagin_config else None
|
limit = pagin_config.limit if pagin_config else None
|
||||||
if limit is None:
|
if limit is None:
|
||||||
|
|
|
@ -554,7 +554,7 @@ class SyncHandler:
|
||||||
event.event_id, state_filter=state_filter
|
event.event_id, state_filter=state_filter
|
||||||
)
|
)
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
state_ids = state_ids.copy()
|
state_ids = dict(state_ids)
|
||||||
state_ids[(event.type, event.state_key)] = event.event_id
|
state_ids[(event.type, event.state_key)] = event.event_id
|
||||||
return state_ids
|
return state_ids
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the
|
||||||
data stores associated with them (e.g. the schema version tables), which are
|
data stores associated with them (e.g. the schema version tables), which are
|
||||||
stored in `synapse.storage.schema`.
|
stored in `synapse.storage.schema`.
|
||||||
"""
|
"""
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from synapse.storage.databases import Databases
|
from synapse.storage.databases import Databases
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
|
@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage
|
||||||
from synapse.storage.purge_events import PurgeEventsStorage
|
from synapse.storage.purge_events import PurgeEventsStorage
|
||||||
from synapse.storage.state import StateGroupStorage
|
from synapse.storage.state import StateGroupStorage
|
||||||
|
|
||||||
__all__ = ["DataStores", "DataStore"]
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Databases", "DataStore"]
|
||||||
|
|
||||||
|
|
||||||
class Storage:
|
class Storage:
|
||||||
"""The high level interfaces for talking to various storage layers.
|
"""The high level interfaces for talking to various storage layers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, stores: Databases):
|
def __init__(self, hs: "HomeServer", stores: Databases):
|
||||||
# We include the main data store here mainly so that we don't have to
|
# We include the main data store here mainly so that we don't have to
|
||||||
# rewrite all the existing code to split it into high vs low level
|
# rewrite all the existing code to split it into high vs low level
|
||||||
# interfaces.
|
# interfaces.
|
||||||
|
|
|
@ -17,14 +17,18 @@
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from abc import ABCMeta
|
from abc import ABCMeta
|
||||||
from typing import Any, Optional
|
from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
|
||||||
|
|
||||||
from synapse.storage.database import LoggingTransaction # noqa: F401
|
from synapse.storage.database import LoggingTransaction # noqa: F401
|
||||||
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
|
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.types import Collection, get_domain_from_id
|
from synapse.storage.types import Connection
|
||||||
|
from synapse.types import Collection, StreamToken, get_domain_from_id
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
per data store (and not one per physical database).
|
per data store (and not one per physical database).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.database_engine = database.engine
|
self.database_engine = database.engine
|
||||||
self.db_pool = database
|
self.db_pool = database
|
||||||
self.rand = random.SystemRandom()
|
self.rand = random.SystemRandom()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(
|
||||||
|
self,
|
||||||
|
stream_name: str,
|
||||||
|
instance_name: str,
|
||||||
|
token: StreamToken,
|
||||||
|
rows: Iterable[Any],
|
||||||
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _invalidate_state_caches(self, room_id, members_changed):
|
def _invalidate_state_caches(
|
||||||
|
self, room_id: str, members_changed: Iterable[str]
|
||||||
|
) -> None:
|
||||||
"""Invalidates caches that are based on the current state, but does
|
"""Invalidates caches that are based on the current state, but does
|
||||||
not stream invalidations down replication.
|
not stream invalidations down replication.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str): Room where state changed
|
room_id: Room where state changed
|
||||||
members_changed (iterable[str]): The user_ids of members that have
|
members_changed: The user_ids of members that have changed
|
||||||
changed
|
|
||||||
"""
|
"""
|
||||||
for host in {get_domain_from_id(u) for u in members_changed}:
|
for host in {get_domain_from_id(u) for u in members_changed}:
|
||||||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||||
|
@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
|
|
||||||
def _attempt_to_invalidate_cache(
|
def _attempt_to_invalidate_cache(
|
||||||
self, cache_name: str, key: Optional[Collection[Any]]
|
self, cache_name: str, key: Optional[Collection[Any]]
|
||||||
):
|
) -> None:
|
||||||
"""Attempts to invalidate the cache of the given name, ignoring if the
|
"""Attempts to invalidate the cache of the given name, ignoring if the
|
||||||
cache doesn't exist. Mainly used for invalidating caches on workers,
|
cache doesn't exist. Mainly used for invalidating caches on workers,
|
||||||
where they may not have the cache.
|
where they may not have the cache.
|
||||||
|
@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
cache.invalidate(tuple(key))
|
cache.invalidate(tuple(key))
|
||||||
|
|
||||||
|
|
||||||
def db_to_json(db_content):
|
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
|
||||||
"""
|
"""
|
||||||
Take some data from a database row and return a JSON-decoded object.
|
Take some data from a database row and return a JSON-decoded object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_content (memoryview|buffer|bytes|bytearray|unicode)
|
db_content: The JSON-encoded contents from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The object decoded from JSON.
|
||||||
"""
|
"""
|
||||||
# psycopg2 on Python 3 returns memoryview objects, which we need to
|
# psycopg2 on Python 3 returns memoryview objects, which we need to
|
||||||
# cast to bytes to decode
|
# cast to bytes to decode
|
||||||
|
|
|
@ -12,29 +12,34 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.storage.types import Connection
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
|
||||||
from . import engines
|
from . import engines
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BackgroundUpdatePerformance:
|
class BackgroundUpdatePerformance:
|
||||||
"""Tracks the how long a background update is taking to update its items"""
|
"""Tracks the how long a background update is taking to update its items"""
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name: str):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.total_item_count = 0
|
self.total_item_count = 0
|
||||||
self.total_duration_ms = 0
|
self.total_duration_ms = 0.0
|
||||||
self.avg_item_count = 0
|
self.avg_item_count = 0.0
|
||||||
self.avg_duration_ms = 0
|
self.avg_duration_ms = 0.0
|
||||||
|
|
||||||
def update(self, item_count, duration_ms):
|
def update(self, item_count: int, duration_ms: float) -> None:
|
||||||
"""Update the stats after doing an update"""
|
"""Update the stats after doing an update"""
|
||||||
self.total_item_count += item_count
|
self.total_item_count += item_count
|
||||||
self.total_duration_ms += duration_ms
|
self.total_duration_ms += duration_ms
|
||||||
|
@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
|
||||||
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
|
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
|
||||||
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
|
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
|
||||||
|
|
||||||
def average_items_per_ms(self):
|
def average_items_per_ms(self) -> Optional[float]:
|
||||||
"""An estimate of how long it takes to do a single update.
|
"""An estimate of how long it takes to do a single update.
|
||||||
Returns:
|
Returns:
|
||||||
A duration in ms as a float
|
A duration in ms as a float
|
||||||
|
@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
|
||||||
# changes in how long the update process takes.
|
# changes in how long the update process takes.
|
||||||
return float(self.avg_item_count) / float(self.avg_duration_ms)
|
return float(self.avg_item_count) / float(self.avg_duration_ms)
|
||||||
|
|
||||||
def total_items_per_ms(self):
|
def total_items_per_ms(self) -> Optional[float]:
|
||||||
"""An estimate of how long it takes to do a single update.
|
"""An estimate of how long it takes to do a single update.
|
||||||
Returns:
|
Returns:
|
||||||
A duration in ms as a float
|
A duration in ms as a float
|
||||||
|
@ -83,21 +88,25 @@ class BackgroundUpdater:
|
||||||
BACKGROUND_UPDATE_INTERVAL_MS = 1000
|
BACKGROUND_UPDATE_INTERVAL_MS = 1000
|
||||||
BACKGROUND_UPDATE_DURATION_MS = 100
|
BACKGROUND_UPDATE_DURATION_MS = 100
|
||||||
|
|
||||||
def __init__(self, hs, database):
|
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.db_pool = database
|
self.db_pool = database
|
||||||
|
|
||||||
# if a background update is currently running, its name.
|
# if a background update is currently running, its name.
|
||||||
self._current_background_update = None # type: Optional[str]
|
self._current_background_update = None # type: Optional[str]
|
||||||
|
|
||||||
self._background_update_performance = {}
|
self._background_update_performance = (
|
||||||
self._background_update_handlers = {}
|
{}
|
||||||
|
) # type: Dict[str, BackgroundUpdatePerformance]
|
||||||
|
self._background_update_handlers = (
|
||||||
|
{}
|
||||||
|
) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
|
||||||
self._all_done = False
|
self._all_done = False
|
||||||
|
|
||||||
def start_doing_background_updates(self):
|
def start_doing_background_updates(self) -> None:
|
||||||
run_as_background_process("background_updates", self.run_background_updates)
|
run_as_background_process("background_updates", self.run_background_updates)
|
||||||
|
|
||||||
async def run_background_updates(self, sleep=True):
|
async def run_background_updates(self, sleep: bool = True) -> None:
|
||||||
logger.info("Starting background schema updates")
|
logger.info("Starting background schema updates")
|
||||||
while True:
|
while True:
|
||||||
if sleep:
|
if sleep:
|
||||||
|
@ -148,7 +157,7 @@ class BackgroundUpdater:
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def has_completed_background_update(self, update_name) -> bool:
|
async def has_completed_background_update(self, update_name: str) -> bool:
|
||||||
"""Check if the given background update has finished running.
|
"""Check if the given background update has finished running.
|
||||||
"""
|
"""
|
||||||
if self._all_done:
|
if self._all_done:
|
||||||
|
@ -173,8 +182,7 @@ class BackgroundUpdater:
|
||||||
Returns once some amount of work is done.
|
Returns once some amount of work is done.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
desired_duration_ms(float): How long we want to spend
|
desired_duration_ms: How long we want to spend updating.
|
||||||
updating.
|
|
||||||
Returns:
|
Returns:
|
||||||
True if we have finished running all the background updates, otherwise False
|
True if we have finished running all the background updates, otherwise False
|
||||||
"""
|
"""
|
||||||
|
@ -220,6 +228,7 @@ class BackgroundUpdater:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _do_background_update(self, desired_duration_ms: float) -> int:
|
async def _do_background_update(self, desired_duration_ms: float) -> int:
|
||||||
|
assert self._current_background_update is not None
|
||||||
update_name = self._current_background_update
|
update_name = self._current_background_update
|
||||||
logger.info("Starting update batch on background update '%s'", update_name)
|
logger.info("Starting update batch on background update '%s'", update_name)
|
||||||
|
|
||||||
|
@ -273,7 +282,11 @@ class BackgroundUpdater:
|
||||||
|
|
||||||
return len(self._background_update_performance)
|
return len(self._background_update_performance)
|
||||||
|
|
||||||
def register_background_update_handler(self, update_name, update_handler):
|
def register_background_update_handler(
|
||||||
|
self,
|
||||||
|
update_name: str,
|
||||||
|
update_handler: Callable[[JsonDict, int], Awaitable[int]],
|
||||||
|
):
|
||||||
"""Register a handler for doing a background update.
|
"""Register a handler for doing a background update.
|
||||||
|
|
||||||
The handler should take two arguments:
|
The handler should take two arguments:
|
||||||
|
@ -287,12 +300,12 @@ class BackgroundUpdater:
|
||||||
The handler is responsible for updating the progress of the update.
|
The handler is responsible for updating the progress of the update.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
update_name(str): The name of the update that this code handles.
|
update_name: The name of the update that this code handles.
|
||||||
update_handler(function): The function that does the update.
|
update_handler: The function that does the update.
|
||||||
"""
|
"""
|
||||||
self._background_update_handlers[update_name] = update_handler
|
self._background_update_handlers[update_name] = update_handler
|
||||||
|
|
||||||
def register_noop_background_update(self, update_name):
|
def register_noop_background_update(self, update_name: str) -> None:
|
||||||
"""Register a noop handler for a background update.
|
"""Register a noop handler for a background update.
|
||||||
|
|
||||||
This is useful when we previously did a background update, but no
|
This is useful when we previously did a background update, but no
|
||||||
|
@ -302,10 +315,10 @@ class BackgroundUpdater:
|
||||||
also be called to clear the update.
|
also be called to clear the update.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
update_name (str): Name of update
|
update_name: Name of update
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def noop_update(progress, batch_size):
|
async def noop_update(progress: JsonDict, batch_size: int) -> int:
|
||||||
await self._end_background_update(update_name)
|
await self._end_background_update(update_name)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
@ -313,14 +326,14 @@ class BackgroundUpdater:
|
||||||
|
|
||||||
def register_background_index_update(
|
def register_background_index_update(
|
||||||
self,
|
self,
|
||||||
update_name,
|
update_name: str,
|
||||||
index_name,
|
index_name: str,
|
||||||
table,
|
table: str,
|
||||||
columns,
|
columns: Iterable[str],
|
||||||
where_clause=None,
|
where_clause: Optional[str] = None,
|
||||||
unique=False,
|
unique: bool = False,
|
||||||
psql_only=False,
|
psql_only: bool = False,
|
||||||
):
|
) -> None:
|
||||||
"""Helper for store classes to do a background index addition
|
"""Helper for store classes to do a background index addition
|
||||||
|
|
||||||
To use:
|
To use:
|
||||||
|
@ -332,19 +345,19 @@ class BackgroundUpdater:
|
||||||
2. In the Store constructor, call this method
|
2. In the Store constructor, call this method
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
update_name (str): update_name to register for
|
update_name: update_name to register for
|
||||||
index_name (str): name of index to add
|
index_name: name of index to add
|
||||||
table (str): table to add index to
|
table: table to add index to
|
||||||
columns (list[str]): columns/expressions to include in index
|
columns: columns/expressions to include in index
|
||||||
unique (bool): true to make a UNIQUE index
|
unique: true to make a UNIQUE index
|
||||||
psql_only: true to only create this index on psql databases (useful
|
psql_only: true to only create this index on psql databases (useful
|
||||||
for virtual sqlite tables)
|
for virtual sqlite tables)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create_index_psql(conn):
|
def create_index_psql(conn: Connection) -> None:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
# postgres insists on autocommit for the index
|
# postgres insists on autocommit for the index
|
||||||
conn.set_session(autocommit=True)
|
conn.set_session(autocommit=True) # type: ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
|
@ -371,9 +384,9 @@ class BackgroundUpdater:
|
||||||
logger.debug("[SQL] %s", sql)
|
logger.debug("[SQL] %s", sql)
|
||||||
c.execute(sql)
|
c.execute(sql)
|
||||||
finally:
|
finally:
|
||||||
conn.set_session(autocommit=False)
|
conn.set_session(autocommit=False) # type: ignore
|
||||||
|
|
||||||
def create_index_sqlite(conn):
|
def create_index_sqlite(conn: Connection) -> None:
|
||||||
# Sqlite doesn't support concurrent creation of indexes.
|
# Sqlite doesn't support concurrent creation of indexes.
|
||||||
#
|
#
|
||||||
# We don't use partial indices on SQLite as it wasn't introduced
|
# We don't use partial indices on SQLite as it wasn't introduced
|
||||||
|
@ -399,7 +412,7 @@ class BackgroundUpdater:
|
||||||
c.execute(sql)
|
c.execute(sql)
|
||||||
|
|
||||||
if isinstance(self.db_pool.engine, engines.PostgresEngine):
|
if isinstance(self.db_pool.engine, engines.PostgresEngine):
|
||||||
runner = create_index_psql
|
runner = create_index_psql # type: Optional[Callable[[Connection], None]]
|
||||||
elif psql_only:
|
elif psql_only:
|
||||||
runner = None
|
runner = None
|
||||||
else:
|
else:
|
||||||
|
@ -433,7 +446,9 @@ class BackgroundUpdater:
|
||||||
"background_updates", keyvalues={"update_name": update_name}
|
"background_updates", keyvalues={"update_name": update_name}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _background_update_progress(self, update_name: str, progress: dict):
|
async def _background_update_progress(
|
||||||
|
self, update_name: str, progress: dict
|
||||||
|
) -> None:
|
||||||
"""Update the progress of a background update
|
"""Update the progress of a background update
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -441,20 +456,22 @@ class BackgroundUpdater:
|
||||||
progress: The progress of the update.
|
progress: The progress of the update.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"background_update_progress",
|
"background_update_progress",
|
||||||
self._background_update_progress_txn,
|
self._background_update_progress_txn,
|
||||||
update_name,
|
update_name,
|
||||||
progress,
|
progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _background_update_progress_txn(self, txn, update_name, progress):
|
def _background_update_progress_txn(
|
||||||
|
self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
|
||||||
|
) -> None:
|
||||||
"""Update the progress of a background update
|
"""Update the progress of a background update
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn(cursor): The transaction.
|
txn: The transaction.
|
||||||
update_name(str): The name of the background update task
|
update_name: The name of the background update task
|
||||||
progress(dict): The progress of the update.
|
progress: The progress of the update.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
progress_json = json_encoder.encode(progress)
|
progress_json = json_encoder.encode(progress)
|
||||||
|
|
|
@ -17,11 +17,12 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
from signedjson.types import VerifyKey
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True)
|
||||||
class FetchKeyResult:
|
class FetchKeyResult:
|
||||||
verify_key = attr.ib() # VerifyKey: the key itself
|
verify_key = attr.ib(type=VerifyKey) # the key itself
|
||||||
valid_until_ts = attr.ib() # int: how long we can use this key for
|
valid_until_ts = attr.ib(type=int) # how long we can use this key for
|
||||||
|
|
|
@ -18,9 +18,10 @@ import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Optional, TextIO
|
from typing import Generator, Iterable, List, Optional, TextIO, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
from typing_extensions import Counter as CounterType
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.storage.database import LoggingDatabaseConnection
|
from synapse.storage.database import LoggingDatabaseConnection
|
||||||
|
@ -70,7 +71,7 @@ def prepare_database(
|
||||||
db_conn: LoggingDatabaseConnection,
|
db_conn: LoggingDatabaseConnection,
|
||||||
database_engine: BaseDatabaseEngine,
|
database_engine: BaseDatabaseEngine,
|
||||||
config: Optional[HomeServerConfig],
|
config: Optional[HomeServerConfig],
|
||||||
databases: Collection[str] = ["main", "state"],
|
databases: Collection[str] = ("main", "state"),
|
||||||
):
|
):
|
||||||
"""Prepares a physical database for usage. Will either create all necessary tables
|
"""Prepares a physical database for usage. Will either create all necessary tables
|
||||||
or upgrade from an older schema version.
|
or upgrade from an older schema version.
|
||||||
|
@ -155,7 +156,9 @@ def prepare_database(
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _setup_new_database(cur, database_engine, databases):
|
def _setup_new_database(
|
||||||
|
cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
|
||||||
|
) -> None:
|
||||||
"""Sets up the physical database by finding a base set of "full schemas" and
|
"""Sets up the physical database by finding a base set of "full schemas" and
|
||||||
then applying any necessary deltas, including schemas from the given data
|
then applying any necessary deltas, including schemas from the given data
|
||||||
stores.
|
stores.
|
||||||
|
@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases):
|
||||||
folder as well those in the data stores specified.
|
folder as well those in the data stores specified.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cur (Cursor): a database cursor
|
cur: a database cursor
|
||||||
database_engine (DatabaseEngine)
|
database_engine
|
||||||
databases (list[str]): The names of the databases to instantiate
|
databases: The names of the databases to instantiate on the given physical database.
|
||||||
on the given physical database.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We're about to set up a brand new database so we check that its
|
# We're about to set up a brand new database so we check that its
|
||||||
|
@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases):
|
||||||
database_engine.check_new_database(cur)
|
database_engine.check_new_database(cur)
|
||||||
|
|
||||||
current_dir = os.path.join(dir_path, "schema", "full_schemas")
|
current_dir = os.path.join(dir_path, "schema", "full_schemas")
|
||||||
directory_entries = os.listdir(current_dir)
|
|
||||||
|
|
||||||
# First we find the highest full schema version we have
|
# First we find the highest full schema version we have
|
||||||
valid_versions = []
|
valid_versions = []
|
||||||
|
|
||||||
for filename in directory_entries:
|
for filename in os.listdir(current_dir):
|
||||||
try:
|
try:
|
||||||
ver = int(filename)
|
ver = int(filename)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases):
|
||||||
for database in databases
|
for database in databases
|
||||||
)
|
)
|
||||||
|
|
||||||
directory_entries = []
|
directory_entries = [] # type: List[_DirectoryListing]
|
||||||
for directory in directories:
|
for directory in directories:
|
||||||
directory_entries.extend(
|
directory_entries.extend(
|
||||||
_DirectoryListing(file_name, os.path.join(directory, file_name))
|
_DirectoryListing(file_name, os.path.join(directory, file_name))
|
||||||
|
@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases):
|
||||||
|
|
||||||
|
|
||||||
def _upgrade_existing_database(
|
def _upgrade_existing_database(
|
||||||
cur,
|
cur: Cursor,
|
||||||
current_version,
|
current_version: int,
|
||||||
applied_delta_files,
|
applied_delta_files: List[str],
|
||||||
upgraded,
|
upgraded: bool,
|
||||||
database_engine,
|
database_engine: BaseDatabaseEngine,
|
||||||
config,
|
config: Optional[HomeServerConfig],
|
||||||
databases,
|
databases: Collection[str],
|
||||||
is_empty=False,
|
is_empty: bool = False,
|
||||||
):
|
) -> None:
|
||||||
"""Upgrades an existing physical database.
|
"""Upgrades an existing physical database.
|
||||||
|
|
||||||
Delta files can either be SQL stored in *.sql files, or python modules
|
Delta files can either be SQL stored in *.sql files, or python modules
|
||||||
|
@ -323,21 +324,20 @@ def _upgrade_existing_database(
|
||||||
for a version before applying those in the next version.
|
for a version before applying those in the next version.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cur (Cursor)
|
cur
|
||||||
current_version (int): The current version of the schema.
|
current_version: The current version of the schema.
|
||||||
applied_delta_files (list): A list of deltas that have already been
|
applied_delta_files: A list of deltas that have already been applied.
|
||||||
applied.
|
upgraded: Whether the current version was generated by having
|
||||||
upgraded (bool): Whether the current version was generated by having
|
|
||||||
applied deltas or from full schema file. If `True` the function
|
applied deltas or from full schema file. If `True` the function
|
||||||
will never apply delta files for the given `current_version`, since
|
will never apply delta files for the given `current_version`, since
|
||||||
the current_version wasn't generated by applying those delta files.
|
the current_version wasn't generated by applying those delta files.
|
||||||
database_engine (DatabaseEngine)
|
database_engine
|
||||||
config (synapse.config.homeserver.HomeServerConfig|None):
|
config:
|
||||||
None if we are initialising a blank database, otherwise the application
|
None if we are initialising a blank database, otherwise the application
|
||||||
config
|
config
|
||||||
databases (list[str]): The names of the databases to instantiate
|
databases: The names of the databases to instantiate
|
||||||
on the given physical database.
|
on the given physical database.
|
||||||
is_empty (bool): Is this a blank database? I.e. do we need to run the
|
is_empty: Is this a blank database? I.e. do we need to run the
|
||||||
upgrade portions of the delta scripts.
|
upgrade portions of the delta scripts.
|
||||||
"""
|
"""
|
||||||
if is_empty:
|
if is_empty:
|
||||||
|
@ -358,6 +358,7 @@ def _upgrade_existing_database(
|
||||||
if not is_empty and "main" in databases:
|
if not is_empty and "main" in databases:
|
||||||
from synapse.storage.databases.main import check_database_before_upgrade
|
from synapse.storage.databases.main import check_database_before_upgrade
|
||||||
|
|
||||||
|
assert config is not None
|
||||||
check_database_before_upgrade(cur, database_engine, config)
|
check_database_before_upgrade(cur, database_engine, config)
|
||||||
|
|
||||||
start_ver = current_version
|
start_ver = current_version
|
||||||
|
@ -388,10 +389,10 @@ def _upgrade_existing_database(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Used to check if we have any duplicate file names
|
# Used to check if we have any duplicate file names
|
||||||
file_name_counter = Counter()
|
file_name_counter = Counter() # type: CounterType[str]
|
||||||
|
|
||||||
# Now find which directories have anything of interest.
|
# Now find which directories have anything of interest.
|
||||||
directory_entries = []
|
directory_entries = [] # type: List[_DirectoryListing]
|
||||||
for directory in directories:
|
for directory in directories:
|
||||||
logger.debug("Looking for schema deltas in %s", directory)
|
logger.debug("Looking for schema deltas in %s", directory)
|
||||||
try:
|
try:
|
||||||
|
@ -445,11 +446,11 @@ def _upgrade_existing_database(
|
||||||
|
|
||||||
module_name = "synapse.storage.v%d_%s" % (v, root_name)
|
module_name = "synapse.storage.v%d_%s" % (v, root_name)
|
||||||
with open(absolute_path) as python_file:
|
with open(absolute_path) as python_file:
|
||||||
module = imp.load_source(module_name, absolute_path, python_file)
|
module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
|
||||||
logger.info("Running script %s", relative_path)
|
logger.info("Running script %s", relative_path)
|
||||||
module.run_create(cur, database_engine)
|
module.run_create(cur, database_engine) # type: ignore
|
||||||
if not is_empty:
|
if not is_empty:
|
||||||
module.run_upgrade(cur, database_engine, config=config)
|
module.run_upgrade(cur, database_engine, config=config) # type: ignore
|
||||||
elif ext == ".pyc" or file_name == "__pycache__":
|
elif ext == ".pyc" or file_name == "__pycache__":
|
||||||
# Sometimes .pyc files turn up anyway even though we've
|
# Sometimes .pyc files turn up anyway even though we've
|
||||||
# disabled their generation; e.g. from distribution package
|
# disabled their generation; e.g. from distribution package
|
||||||
|
@ -497,14 +498,15 @@ def _upgrade_existing_database(
|
||||||
logger.info("Schema now up to date")
|
logger.info("Schema now up to date")
|
||||||
|
|
||||||
|
|
||||||
def _apply_module_schemas(txn, database_engine, config):
|
def _apply_module_schemas(
|
||||||
|
txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
|
||||||
|
) -> None:
|
||||||
"""Apply the module schemas for the dynamic modules, if any
|
"""Apply the module schemas for the dynamic modules, if any
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cur: database cursor
|
cur: database cursor
|
||||||
database_engine: synapse database engine class
|
database_engine:
|
||||||
config (synapse.config.homeserver.HomeServerConfig):
|
config: application config
|
||||||
application config
|
|
||||||
"""
|
"""
|
||||||
for (mod, _config) in config.password_providers:
|
for (mod, _config) in config.password_providers:
|
||||||
if not hasattr(mod, "get_db_schema_files"):
|
if not hasattr(mod, "get_db_schema_files"):
|
||||||
|
@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
|
def _apply_module_schema_files(
|
||||||
|
cur: Cursor,
|
||||||
|
database_engine: BaseDatabaseEngine,
|
||||||
|
modname: str,
|
||||||
|
names_and_streams: Iterable[Tuple[str, TextIO]],
|
||||||
|
) -> None:
|
||||||
"""Apply the module schemas for a single module
|
"""Apply the module schemas for a single module
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cur: database cursor
|
cur: database cursor
|
||||||
database_engine: synapse database engine class
|
database_engine: synapse database engine class
|
||||||
modname (str): fully qualified name of the module
|
modname: fully qualified name of the module
|
||||||
names_and_streams (Iterable[(str, file)]): the names and streams of
|
names_and_streams: the names and streams of schemas to be applied
|
||||||
schemas to be applied
|
|
||||||
"""
|
"""
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
|
"SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
|
||||||
|
@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_statements(f):
|
def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
|
||||||
statement_buffer = ""
|
statement_buffer = ""
|
||||||
in_comment = False # If we're in a /* ... */ style comment
|
in_comment = False # If we're in a /* ... */ style comment
|
||||||
|
|
||||||
|
@ -594,17 +600,19 @@ def get_statements(f):
|
||||||
statement_buffer = statements[-1].strip()
|
statement_buffer = statements[-1].strip()
|
||||||
|
|
||||||
|
|
||||||
def executescript(txn, schema_path):
|
def executescript(txn: Cursor, schema_path: str) -> None:
|
||||||
with open(schema_path, "r") as f:
|
with open(schema_path, "r") as f:
|
||||||
execute_statements_from_stream(txn, f)
|
execute_statements_from_stream(txn, f)
|
||||||
|
|
||||||
|
|
||||||
def execute_statements_from_stream(cur: Cursor, f: TextIO):
|
def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
|
||||||
for statement in get_statements(f):
|
for statement in get_statements(f):
|
||||||
cur.execute(statement)
|
cur.execute(statement)
|
||||||
|
|
||||||
|
|
||||||
def _get_or_create_schema_state(txn, database_engine):
|
def _get_or_create_schema_state(
|
||||||
|
txn: Cursor, database_engine: BaseDatabaseEngine
|
||||||
|
) -> Optional[Tuple[int, List[str], bool]]:
|
||||||
# Bluntly try creating the schema_version tables.
|
# Bluntly try creating the schema_version tables.
|
||||||
schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
|
schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
|
||||||
executescript(txn, schema_path)
|
executescript(txn, schema_path)
|
||||||
|
@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
|
||||||
txn.execute("SELECT version, upgraded FROM schema_version")
|
txn.execute("SELECT version, upgraded FROM schema_version")
|
||||||
row = txn.fetchone()
|
row = txn.fetchone()
|
||||||
current_version = int(row[0]) if row else None
|
current_version = int(row[0]) if row else None
|
||||||
upgraded = bool(row[1]) if row else None
|
|
||||||
|
|
||||||
if current_version:
|
if current_version:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
|
||||||
(current_version,),
|
(current_version,),
|
||||||
)
|
)
|
||||||
applied_deltas = [d for d, in txn]
|
applied_deltas = [d for d, in txn]
|
||||||
|
upgraded = bool(row[1])
|
||||||
return current_version, applied_deltas, upgraded
|
return current_version, applied_deltas, upgraded
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
@ -634,5 +642,5 @@ class _DirectoryListing:
|
||||||
`file_name` attr is kept first.
|
`file_name` attr is kept first.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
file_name = attr.ib()
|
file_name = attr.ib(type=str)
|
||||||
absolute_path = attr.ib()
|
absolute_path = attr.ib(type=str)
|
||||||
|
|
|
@ -15,7 +15,12 @@
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import Set
|
from typing import TYPE_CHECKING, Set
|
||||||
|
|
||||||
|
from synapse.storage.databases import Databases
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -24,10 +29,10 @@ class PurgeEventsStorage:
|
||||||
"""High level interface for purging rooms and event history.
|
"""High level interface for purging rooms and event history.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, stores):
|
def __init__(self, hs: "HomeServer", stores: Databases):
|
||||||
self.stores = stores
|
self.stores = stores
|
||||||
|
|
||||||
async def purge_room(self, room_id: str):
|
async def purge_room(self, room_id: str) -> None:
|
||||||
"""Deletes all record of a room
|
"""Deletes all record of a room
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -14,10 +14,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -27,18 +29,18 @@ class PaginationChunk:
|
||||||
"""Returned by relation pagination APIs.
|
"""Returned by relation pagination APIs.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
chunk (list): The rows returned by pagination
|
chunk: The rows returned by pagination
|
||||||
next_batch (Any|None): Token to fetch next set of results with, if
|
next_batch: Token to fetch next set of results with, if
|
||||||
None then there are no more results.
|
None then there are no more results.
|
||||||
prev_batch (Any|None): Token to fetch previous set of results with, if
|
prev_batch: Token to fetch previous set of results with, if
|
||||||
None then there are no previous results.
|
None then there are no previous results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
chunk = attr.ib()
|
chunk = attr.ib(type=List[JsonDict])
|
||||||
next_batch = attr.ib(default=None)
|
next_batch = attr.ib(type=Optional[Any], default=None)
|
||||||
prev_batch = attr.ib(default=None)
|
prev_batch = attr.ib(type=Optional[Any], default=None)
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
d = {"chunk": self.chunk}
|
d = {"chunk": self.chunk}
|
||||||
|
|
||||||
if self.next_batch:
|
if self.next_batch:
|
||||||
|
@ -59,25 +61,25 @@ class RelationPaginationToken:
|
||||||
boundaries of the chunk as pagination tokens.
|
boundaries of the chunk as pagination tokens.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
topological (int): The topological ordering of the boundary event
|
topological: The topological ordering of the boundary event
|
||||||
stream (int): The stream ordering of the boundary event.
|
stream: The stream ordering of the boundary event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
topological = attr.ib()
|
topological = attr.ib(type=int)
|
||||||
stream = attr.ib()
|
stream = attr.ib(type=int)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_string(string):
|
def from_string(string: str) -> "RelationPaginationToken":
|
||||||
try:
|
try:
|
||||||
t, s = string.split("-")
|
t, s = string.split("-")
|
||||||
return RelationPaginationToken(int(t), int(s))
|
return RelationPaginationToken(int(t), int(s))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise SynapseError(400, "Invalid token")
|
raise SynapseError(400, "Invalid token")
|
||||||
|
|
||||||
def to_string(self):
|
def to_string(self) -> str:
|
||||||
return "%d-%d" % (self.topological, self.stream)
|
return "%d-%d" % (self.topological, self.stream)
|
||||||
|
|
||||||
def as_tuple(self):
|
def as_tuple(self) -> Tuple[Any, ...]:
|
||||||
return attr.astuple(self)
|
return attr.astuple(self)
|
||||||
|
|
||||||
|
|
||||||
|
@ -89,23 +91,23 @@ class AggregationPaginationToken:
|
||||||
aggregation groups, we can just use them as our pagination token.
|
aggregation groups, we can just use them as our pagination token.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
count (int): The count of relations in the boundar group.
|
count: The count of relations in the boundary group.
|
||||||
stream (int): The MAX stream ordering in the boundary group.
|
stream: The MAX stream ordering in the boundary group.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
count = attr.ib()
|
count = attr.ib(type=int)
|
||||||
stream = attr.ib()
|
stream = attr.ib(type=int)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_string(string):
|
def from_string(string: str) -> "AggregationPaginationToken":
|
||||||
try:
|
try:
|
||||||
c, s = string.split("-")
|
c, s = string.split("-")
|
||||||
return AggregationPaginationToken(int(c), int(s))
|
return AggregationPaginationToken(int(c), int(s))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise SynapseError(400, "Invalid token")
|
raise SynapseError(400, "Invalid token")
|
||||||
|
|
||||||
def to_string(self):
|
def to_string(self) -> str:
|
||||||
return "%d-%d" % (self.count, self.stream)
|
return "%d-%d" % (self.count, self.stream)
|
||||||
|
|
||||||
def as_tuple(self):
|
def as_tuple(self) -> Tuple[Any, ...]:
|
||||||
return attr.astuple(self)
|
return attr.astuple(self)
|
||||||
|
|
|
@ -12,9 +12,18 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Awaitable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import MutableStateMap, StateMap
|
from synapse.types import MutableStateMap, StateMap
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
from synapse.storage.databases import Databases
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Used for generic functions below
|
# Used for generic functions below
|
||||||
|
@ -330,10 +343,12 @@ class StateGroupStorage:
|
||||||
"""High level interface to fetching state for event.
|
"""High level interface to fetching state for event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, stores):
|
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
||||||
self.stores = stores
|
self.stores = stores
|
||||||
|
|
||||||
async def get_state_group_delta(self, state_group: int):
|
async def get_state_group_delta(
|
||||||
|
self, state_group: int
|
||||||
|
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
|
||||||
"""Given a state group try to return a previous group and a delta between
|
"""Given a state group try to return a previous group and a delta between
|
||||||
the old and the new.
|
the old and the new.
|
||||||
|
|
||||||
|
@ -341,8 +356,8 @@ class StateGroupStorage:
|
||||||
state_group: The state group used to retrieve state deltas.
|
state_group: The state group used to retrieve state deltas.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Optional[int], Optional[StateMap[str]]]:
|
A tuple of the previous group and a state map of the event IDs which
|
||||||
(prev_group, delta_ids)
|
make up the delta between the old and new state groups.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return await self.stores.state.get_state_group_delta(state_group)
|
return await self.stores.state.get_state_group_delta(state_group)
|
||||||
|
@ -436,7 +451,7 @@ class StateGroupStorage:
|
||||||
|
|
||||||
async def get_state_for_events(
|
async def get_state_for_events(
|
||||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||||
):
|
) -> Dict[str, StateMap[EventBase]]:
|
||||||
"""Given a list of event_ids and type tuples, return a list of state
|
"""Given a list of event_ids and type tuples, return a list of state
|
||||||
dicts for each event.
|
dicts for each event.
|
||||||
|
|
||||||
|
@ -472,7 +487,7 @@ class StateGroupStorage:
|
||||||
|
|
||||||
async def get_state_ids_for_events(
|
async def get_state_ids_for_events(
|
||||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||||
):
|
) -> Dict[str, StateMap[str]]:
|
||||||
"""
|
"""
|
||||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||||
of the state events (as opposed to the events themselves)
|
of the state events (as opposed to the events themselves)
|
||||||
|
@ -500,7 +515,7 @@ class StateGroupStorage:
|
||||||
|
|
||||||
async def get_state_for_event(
|
async def get_state_for_event(
|
||||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||||
):
|
) -> StateMap[EventBase]:
|
||||||
"""
|
"""
|
||||||
Get the state dict corresponding to a particular event
|
Get the state dict corresponding to a particular event
|
||||||
|
|
||||||
|
@ -516,7 +531,7 @@ class StateGroupStorage:
|
||||||
|
|
||||||
async def get_state_ids_for_event(
|
async def get_state_ids_for_event(
|
||||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||||
):
|
) -> StateMap[str]:
|
||||||
"""
|
"""
|
||||||
Get the state dict corresponding to a particular event
|
Get the state dict corresponding to a particular event
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue