Method to set statement timeout.

This commit is contained in:
Patrick Cloke 2023-11-15 13:38:57 -05:00
parent 958580cd11
commit 17ba1a9c1e
4 changed files with 26 additions and 15 deletions

View File

@ -755,6 +755,8 @@ class BackgroundUpdater:
# postgres insists on autocommit for the index # postgres insists on autocommit for the index
conn.engine.attempt_to_set_autocommit(conn.conn, True) conn.engine.attempt_to_set_autocommit(conn.conn, True)
assert isinstance(self.db_pool.engine, PostgresEngine)
try: try:
c = conn.cursor() c = conn.cursor()
@ -768,8 +770,7 @@ class BackgroundUpdater:
# override the global statement timeout to avoid accidentally squashing # override the global statement timeout to avoid accidentally squashing
# a long-running index creation process # a long-running index creation process
timeout_sql = "SET SESSION statement_timeout = 0" self.db_pool.engine.set_statement_timeout(c, 0)
c.execute(timeout_sql)
sql = ( sql = (
"CREATE %(unique)s INDEX CONCURRENTLY %(name)s" "CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
@ -791,11 +792,11 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql) logger.debug("[SQL] %s", sql)
c.execute(sql) c.execute(sql)
finally: finally:
# mypy ignore - `statement_timeout` is defined on PostgresEngine
# reset the global timeout to the default # reset the global timeout to the default
default_timeout = self.db_pool.engine.statement_timeout # type: ignore[attr-defined] if self.db_pool.engine.statement_timeout is not None:
undo_timeout_sql = f"SET statement_timeout = {default_timeout}" self.db_pool.engine.set_statement_timeout(
conn.cursor().execute(undo_timeout_sql) conn.cursor(), self.db_pool.engine.statement_timeout
)
conn.engine.attempt_to_set_autocommit(conn.conn, False) conn.engine.attempt_to_set_autocommit(conn.conn, False)

View File

@ -64,6 +64,11 @@ class PostgresEngine(
""" """
... ...
@abc.abstractmethod
def set_statement_timeout(self, cursor: CursorType, statement_timeout: int) -> None:
"""Configure the current cursor's statement timeout."""
...
@property @property
def single_threaded(self) -> bool: def single_threaded(self) -> bool:
return False return False
@ -168,15 +173,7 @@ class PostgresEngine(
# Abort really long-running statements and turn them into errors. # Abort really long-running statements and turn them into errors.
if self.statement_timeout is not None: if self.statement_timeout is not None:
# TODO Avoid a circular import, this needs to be abstracted. self.set_statement_timeout(cursor.txn, self.statement_timeout) # type: ignore[arg-type]
if self.__class__.__name__ == "Psycopg2Engine":
cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,))
else:
cursor.execute(
sql.SQL("SET statement_timeout TO {}").format(
self.statement_timeout
)
)
cursor.close() cursor.close()
db_conn.commit() db_conn.commit()

View File

@ -52,6 +52,14 @@ class PsycopgEngine(
def get_server_version(self, db_conn: psycopg.Connection) -> int: def get_server_version(self, db_conn: psycopg.Connection) -> int:
return db_conn.info.server_version return db_conn.info.server_version
def set_statement_timeout(
self, cursor: psycopg.Cursor, statement_timeout: int
) -> None:
"""Configure the current cursor's statement timeout."""
cursor.execute(
psycopg.sql.SQL("SET statement_timeout TO {}").format(statement_timeout)
)
def convert_param_style(self, sql: str) -> str: def convert_param_style(self, sql: str) -> str:
# if isinstance(sql, psycopg.sql.Composed): # if isinstance(sql, psycopg.sql.Composed):
# return sql # return sql

View File

@ -51,6 +51,11 @@ class Psycopg2Engine(
def get_server_version(self, db_conn: psycopg2.extensions.connection) -> int: def get_server_version(self, db_conn: psycopg2.extensions.connection) -> int:
return db_conn.server_version return db_conn.server_version
def set_statement_timeout(
self, cursor: psycopg2.extensions.cursor, statement_timeout: int
) -> None:
cursor.execute("SET statement_timeout TO ?", (statement_timeout,))
def convert_param_style(self, sql: str) -> str: def convert_param_style(self, sql: str) -> str:
return sql.replace("?", "%s") return sql.replace("?", "%s")