Separate PostgresEngine into Psycopg2Engine and PsycopgEngine.

This commit is contained in:
Patrick Cloke 2023-11-09 15:11:52 -05:00
parent f725dc58ed
commit e32f49a24b
11 changed files with 259 additions and 77 deletions

View File

@ -1390,7 +1390,7 @@ def main() -> None:
if "name" not in postgres_config:
sys.stderr.write("Malformed database config: no 'name'\n")
sys.exit(2)
if postgres_config["name"] != "psycopg2":
if postgres_config["name"] not in ("psycopg", "psycopg2"):
sys.stderr.write("Database must use the 'psycopg2' connector.\n")
sys.exit(3)

View File

@ -50,7 +50,7 @@ class DatabaseConnectionConfig:
def __init__(self, name: str, db_config: dict):
db_engine = db_config.get("name", "sqlite3")
if db_engine not in ("sqlite3", "psycopg2"):
if db_engine not in ("sqlite3", "psycopg2", "psycopg"):
raise ConfigError("Unsupported database type %r" % (db_engine,))
if db_engine == "sqlite3":

View File

@ -14,6 +14,8 @@
from typing import Any, Mapping, NoReturn
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
# The classes `PostgresEngine` and `Sqlite3Engine` must always be importable, because
# we use `isinstance(engine, PostgresEngine)` to write different queries for postgres
@ -21,16 +23,27 @@ from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
# installed. To account for this, create dummy classes on import failure so we can
# still run `isinstance()` checks.
try:
from .postgres import PostgresEngine
from .psycopg2 import Psycopg2Engine
except ImportError:
class PostgresEngine(BaseDatabaseEngine): # type: ignore[no-redef]
class Psycopg2Engine(BaseDatabaseEngine): # type: ignore[no-redef]
def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
raise RuntimeError(
f"Cannot create {cls.__name__} -- psycopg2 module is not installed"
)
try:
from .psycopg import PsycopgEngine
except ImportError:
class PsycopgEngine(BaseDatabaseEngine): # type: ignore[no-redef]
def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
raise RuntimeError(
f"Cannot create {cls.__name__} -- psycopg module is not installed"
)
try:
from .sqlite import Sqlite3Engine
except ImportError:
@ -49,7 +62,10 @@ def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
return Sqlite3Engine(database_config)
if name == "psycopg2":
return PostgresEngine(database_config)
return Psycopg2Engine(database_config)
if name == "psycopg":
return PsycopgEngine(database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))

View File

@ -33,9 +33,12 @@ class IncorrectDatabaseSetup(RuntimeError):
ConnectionType = TypeVar("ConnectionType", bound=Connection)
CursorType = TypeVar("CursorType", bound=Cursor)
IsolationLevelType = TypeVar("IsolationLevelType")
class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCMeta):
class BaseDatabaseEngine(
Generic[ConnectionType, CursorType, IsolationLevelType], metaclass=abc.ABCMeta
):
def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]):
self.module = module
@ -124,7 +127,7 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
@abc.abstractmethod
def attempt_to_set_isolation_level(
self, conn: ConnectionType, isolation_level: Optional[int]
self, conn: ConnectionType, isolation_level: Optional[IsolationLevelType]
) -> None:
"""Attempt to set the connections isolation level.

View File

@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
import psycopg2.extensions
from typing import TYPE_CHECKING, Any, Mapping, Optional, Tuple, Type, cast, Generic
from synapse.storage.engines._base import (
BaseDatabaseEngine,
ConnectionType,
CursorType,
IncorrectDatabaseSetup,
IsolationLevel,
IsolationLevelType,
)
from synapse.storage.types import Cursor
from synapse.storage.types import Cursor, DBAPI2Module
if TYPE_CHECKING:
from synapse.storage.database import LoggingDatabaseConnection
@ -32,19 +33,16 @@ logger = logging.getLogger(__name__)
class PostgresEngine(
BaseDatabaseEngine[psycopg2.extensions.connection, psycopg2.extensions.cursor]
Generic[ConnectionType, CursorType, IsolationLevelType],
BaseDatabaseEngine[ConnectionType, CursorType, IsolationLevelType],
metaclass=abc.ABCMeta,
):
def __init__(self, database_config: Mapping[str, Any]):
super().__init__(psycopg2, database_config)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
isolation_level_map: Mapping[int, IsolationLevelType]
default_isolation_level: IsolationLevelType
# Disables passing `bytes` to txn.execute, c.f.
# https://github.com/matrix-org/synapse/issues/6186. If you do
# actually want to use bytes than wrap it in `bytearray`.
def _disable_bytes_adapter(_: bytes) -> NoReturn:
raise Exception("Passing bytes to DB is disabled.")
def __init__(self, module: DBAPI2Module, database_config: Mapping[str, Any]):
super().__init__(module, database_config)
psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
# Set the statement timeout to 1 hour by default.
# Any query taking more than 1 hour should probably be considered a bug;
@ -57,16 +55,15 @@ class PostgresEngine(
)
self._version: Optional[int] = None # unknown as yet
self.isolation_level_map: Mapping[int, int] = {
IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
}
self.default_isolation_level = (
psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
self.config = database_config
@abc.abstractmethod
def get_server_version(self, db_conn: ConnectionType) -> int:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
...
@property
def single_threaded(self) -> bool:
return False
@ -80,21 +77,22 @@ class PostgresEngine(
def check_database(
self,
db_conn: psycopg2.extensions.connection,
db_conn: ConnectionType,
allow_outdated_version: bool = False,
) -> None:
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
self._version = self.get_server_version(db_conn)
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
if not allow_outdated_version and self._version < 110000:
raise RuntimeError("Synapse requires PostgreSQL 11 or above.")
with db_conn.cursor() as txn:
# psycopg and psycopg2 both support using cursors as context managers.
with db_conn.cursor() as txn: # type: ignore[attr-defined]
txn.execute("SHOW SERVER_ENCODING")
rows = txn.fetchall()
if rows and rows[0][0] != "UTF8":
@ -155,7 +153,8 @@ class PostgresEngine(
return sql.replace("?", "%s")
def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
db_conn.set_isolation_level(self.default_isolation_level)
# mypy doesn't realize that ConnectionType matches the Connection protocol.
self.attempt_to_set_isolation_level(db_conn.conn, self.default_isolation_level) # type: ignore[arg-type]
# Set the bytea output to escape, vs the default of hex
cursor = db_conn.cursor()
@ -169,7 +168,15 @@ class PostgresEngine(
# Abort really long-running statements and turn them into errors.
if self.statement_timeout is not None:
# TODO Avoid a circular import, this needs to be abstracted.
if self.__class__.__name__ == "Psycopg2Engine":
cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,))
else:
cursor.execute(
sql.SQL("SET statement_timeout TO {}").format(
self.statement_timeout
)
)
cursor.close()
db_conn.commit()
@ -184,16 +191,9 @@ class PostgresEngine(
"""Do we support the `RETURNING` clause in insert/update/delete?"""
return True
def is_deadlock(self, error: Exception) -> bool:
if isinstance(error, psycopg2.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
# "40001" serialization_failure
# "40P01" deadlock_detected
return error.pgcode in ["40001", "40P01"]
return False
def is_connection_closed(self, conn: psycopg2.extensions.connection) -> bool:
return bool(conn.closed)
def is_connection_closed(self, conn: ConnectionType) -> bool:
# Both psycopg and psycopg2 connections have a closed attributed.
return bool(conn.closed) # type: ignore[attr-defined]
def lock_table(self, txn: Cursor, table: str) -> None:
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
@ -216,25 +216,8 @@ class PostgresEngine(
def row_id_name(self) -> str:
return "ctid"
def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
return conn.status != psycopg2.extensions.STATUS_READY
def attempt_to_set_autocommit(
self, conn: psycopg2.extensions.connection, autocommit: bool
) -> None:
return conn.set_session(autocommit=autocommit)
def attempt_to_set_isolation_level(
self, conn: psycopg2.extensions.connection, isolation_level: Optional[int]
) -> None:
if isolation_level is None:
isolation_level = self.default_isolation_level
else:
isolation_level = self.isolation_level_map[isolation_level]
return conn.set_isolation_level(isolation_level)
@staticmethod
def executescript(cursor: psycopg2.extensions.cursor, script: str) -> None:
def executescript(cursor: CursorType, script: str) -> None:
"""Execute a chunk of SQL containing multiple semicolon-delimited statements.
Psycopg2 seems happy to do this in DBAPI2's `execute()` function.

View File

@ -0,0 +1,89 @@
# Copyright 2022-2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Mapping, Optional, Tuple
import psycopg
import psycopg.errors
import psycopg.sql
from twisted.enterprise.adbapi import Connection as TxConnection
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
logger = logging.getLogger(__name__)
class PsycopgEngine(
# mypy doesn't seem to like that the psycopg Connection and Cursor are Generics.
PostgresEngine[ # type: ignore[type-var]
psycopg.Connection[Tuple], psycopg.Cursor[Tuple], psycopg.IsolationLevel
]
):
def __init__(self, database_config: Mapping[str, Any]):
super().__init__(psycopg, database_config) # type: ignore[arg-type]
# psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
# actually want to use bytes than wrap it in `bytearray`.
# def _disable_bytes_adapter(_: bytes) -> NoReturn:
# raise Exception("Passing bytes to DB is disabled.")
self.isolation_level_map = {
IsolationLevel.READ_COMMITTED: psycopg.IsolationLevel.READ_COMMITTED,
IsolationLevel.REPEATABLE_READ: psycopg.IsolationLevel.REPEATABLE_READ,
IsolationLevel.SERIALIZABLE: psycopg.IsolationLevel.SERIALIZABLE,
}
self.default_isolation_level = psycopg.IsolationLevel.REPEATABLE_READ
def get_server_version(self, db_conn: psycopg.Connection) -> int:
return db_conn.info.server_version
def convert_param_style(self, sql: str) -> str:
# if isinstance(sql, psycopg.sql.Composed):
# return sql
return sql.replace("?", "%s")
def is_deadlock(self, error: Exception) -> bool:
if isinstance(error, psycopg.errors.Error):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
# "40001" serialization_failure
# "40P01" deadlock_detected
return error.sqlstate in ["40001", "40P01"]
return False
def in_transaction(self, conn: psycopg.Connection) -> bool:
return conn.info.transaction_status != psycopg.pq.TransactionStatus.IDLE
def attempt_to_set_autocommit(
self, conn: psycopg.Connection, autocommit: bool
) -> None:
# Sometimes this gets called with a Twisted connection instead, unwrap
# it because it doesn't support __setattr__.
if isinstance(conn, TxConnection):
conn = conn._connection
conn.autocommit = autocommit
def attempt_to_set_isolation_level(
self, conn: psycopg.Connection, isolation_level: Optional[int]
) -> None:
if isolation_level is None:
pg_isolation_level = self.default_isolation_level
else:
pg_isolation_level = self.isolation_level_map[isolation_level]
conn.isolation_level = pg_isolation_level

View File

@ -0,0 +1,93 @@
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Mapping, Optional, NoReturn
import psycopg2.extensions
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
logger = logging.getLogger(__name__)
class Psycopg2Engine(
PostgresEngine[psycopg2.extensions.connection, psycopg2.extensions.cursor, int]
):
def __init__(self, database_config: Mapping[str, Any]):
super().__init__(psycopg2, database_config)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f.
# https://github.com/matrix-org/synapse/issues/6186. If you do
# actually want to use bytes than wrap it in `bytearray`.
def _disable_bytes_adapter(_: bytes) -> NoReturn:
raise Exception("Passing bytes to DB is disabled.")
psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
self.isolation_level_map = {
IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
}
self.default_isolation_level = (
psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
self.config = database_config
def get_server_version(self, db_conn: psycopg2.extensions.connection) -> int:
return db_conn.server_version
def convert_param_style(self, sql: str) -> str:
return sql.replace("?", "%s")
def is_deadlock(self, error: Exception) -> bool:
if isinstance(error, psycopg2.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
# "40001" serialization_failure
# "40P01" deadlock_detected
return error.pgcode in ["40001", "40P01"]
return False
def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
return conn.status != psycopg2.extensions.STATUS_READY
def attempt_to_set_autocommit(
self, conn: psycopg2.extensions.connection, autocommit: bool
) -> None:
return conn.set_session(autocommit=autocommit)
def attempt_to_set_isolation_level(
self, conn: psycopg2.extensions.connection, isolation_level: Optional[int]
) -> None:
if isolation_level is None:
isolation_level = self.default_isolation_level
else:
isolation_level = self.isolation_level_map[isolation_level]
return conn.set_isolation_level(isolation_level)
@staticmethod
def executescript(cursor: psycopg2.extensions.cursor, script: str) -> None:
"""Execute a chunk of SQL containing multiple semicolon-delimited statements.
Psycopg2 seems happy to do this in DBAPI2's `execute()` function.
For consistency with SQLite, any ongoing transaction is committed before
executing the script in its own transaction. The script transaction is
left open and it is the responsibility of the caller to commit it.
"""
cursor.execute(f"COMMIT; BEGIN TRANSACTION; {script}")

View File

@ -24,7 +24,7 @@ if TYPE_CHECKING:
from synapse.storage.database import LoggingDatabaseConnection
class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor, int]):
def __init__(self, database_config: Mapping[str, Any]):
super().__init__(sqlite3, database_config)

View File

@ -190,14 +190,7 @@ class DBAPI2Module(Protocol):
def NotSupportedError(self) -> Type[Exception]:
...
# We originally wrote
# def connect(self, *args, **kwargs) -> Connection: ...
# But mypy doesn't seem to like that because sqlite3.connect takes a mandatory
# positional argument. We can't make that part of the signature though, because
# psycopg2.connect doesn't have a mandatory positional argument. Instead, we use
# the following slightly unusual workaround.
@property
def connect(self) -> Callable[..., Connection]:
def connect(self, *args: Any, **kwargs: Any) -> Connection:
...

View File

@ -971,8 +971,12 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
if USE_POSTGRES_FOR_TESTS == "psycopg":
db_type = "psycopg"
else:
db_type = "psycopg2"
database_config = {
"name": "psycopg2",
"name": db_type,
"args": {
"dbname": test_db,
"host": POSTGRES_HOST,
@ -1067,8 +1071,6 @@ def setup_test_homeserver(
# We need to do cleanup on PostgreSQL
def cleanup() -> None:
import psycopg2
# Close all the db pools
database_pool._db_pool.close()
@ -1094,7 +1096,7 @@ def setup_test_homeserver(
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
db_conn.commit()
dropped = True
except psycopg2.OperationalError as e:
except db_engine.module.OperationalError as e:
warnings.warn(
"Couldn't drop old db: " + str(e),
category=UserWarning,

View File

@ -59,6 +59,9 @@ def setupdb() -> None:
# If we're using PostgreSQL, set up the db once
if USE_POSTGRES_FOR_TESTS:
# create a PostgresEngine
if USE_POSTGRES_FOR_TESTS == "psycopg":
db_engine = create_engine({"name": "psycopg", "args": {}})
else:
db_engine = create_engine({"name": "psycopg2", "args": {}})
# connect to postgres to create the base database.
db_conn = db_engine.module.connect(