From ec9ff389f4c64d31da46b904381087aef0c86796 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 7 Nov 2023 09:34:23 -0500 Subject: [PATCH] More tests for the simple_* methods. (#16596) Expand tests for the simple_* database methods, additionally test against both PostgreSQL and SQLite variants. --- changelog.d/16596.misc | 1 + synapse/storage/database.py | 13 +- tests/storage/test_base.py | 646 +++++++++++++++++++++++++++++++++++- 3 files changed, 633 insertions(+), 27 deletions(-) create mode 100644 changelog.d/16596.misc diff --git a/changelog.d/16596.misc b/changelog.d/16596.misc new file mode 100644 index 0000000000..fa457b12e5 --- /dev/null +++ b/changelog.d/16596.misc @@ -0,0 +1 @@ +Improve tests of the SQL generator. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6d54bb0eb2..abc7d8a5d2 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1401,12 +1401,12 @@ class DatabasePool: allvalues.update(values) latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % ( + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %sDO %s" % ( table, ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues), ", ".join(k for k in keyvalues), - f"WHERE {where_clause}" if where_clause else "", + f"WHERE {where_clause} " if where_clause else "", latter, ) txn.execute(sql, list(allvalues.values())) @@ -2062,9 +2062,7 @@ class DatabasePool: where_clause = "" # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ? - sql = f""" - UPDATE {table} SET {set_clause} {where_clause} - """ + sql = f"UPDATE {table} SET {set_clause} {where_clause}" txn.execute_batch(sql, args) @@ -2283,8 +2281,6 @@ class DatabasePool: if not values: return 0 - sql = "DELETE FROM %s" % table - clause, values = make_in_list_sql_clause(txn.database_engine, column, values) clauses = [clause] @@ -2292,8 +2288,7 @@ class DatabasePool: clauses.append("%s = ?" % (key,)) values.append(value) - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses)) txn.execute(sql, values) return txn.rowcount diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index e4a52c301e..b4c490b568 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -14,7 +14,7 @@ from collections import OrderedDict from typing import Generator -from unittest.mock import Mock +from unittest.mock import Mock, call, patch from twisted.internet import defer @@ -24,43 +24,90 @@ from synapse.storage.engines import create_engine from tests import unittest from tests.server import TestHomeServer -from tests.utils import default_config +from tests.utils import USE_POSTGRES_FOR_TESTS, default_config class SQLBaseStoreTestCase(unittest.TestCase): """Test the "simple" SQL generating methods in SQLBaseStore.""" def setUp(self) -> None: - self.db_pool = Mock(spec=["runInteraction"]) + # This is the Twisted connection pool. + conn_pool = Mock(spec=["runInteraction", "runWithConnection"]) self.mock_txn = Mock() - self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"]) + if USE_POSTGRES_FOR_TESTS: + # To avoid testing psycopg2 itself, patch execute_batch/execute_values + # to assert how it is called. + from psycopg2 import extras + + self.mock_execute_batch = Mock() + self.execute_batch_patcher = patch.object( + extras, "execute_batch", new=self.mock_execute_batch + ) + self.execute_batch_patcher.start() + self.mock_execute_values = Mock() + self.execute_values_patcher = patch.object( + extras, "execute_values", new=self.mock_execute_values + ) + self.execute_values_patcher.start() + + self.mock_conn = Mock( + spec_set=[ + "cursor", + "rollback", + "commit", + "closed", + "reconnect", + "set_session", + "encoding", + ] + ) + self.mock_conn.encoding = "UNICODE" + else: + self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"]) self.mock_conn.cursor.return_value = self.mock_txn + self.mock_txn.connection = self.mock_conn self.mock_conn.rollback.return_value = None # Our fake runInteraction just runs synchronously inline def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def] return defer.succeed(func(self.mock_txn, *args, **kwargs)) - self.db_pool.runInteraction = runInteraction + conn_pool.runInteraction = runInteraction def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def] return defer.succeed(func(self.mock_conn, *args, **kwargs)) - self.db_pool.runWithConnection = runWithConnection + conn_pool.runWithConnection = runWithConnection config = default_config(name="test", parse=True) hs = TestHomeServer("test", config=config) - sqlite_config = {"name": "sqlite3"} - engine = create_engine(sqlite_config) + if USE_POSTGRES_FOR_TESTS: + db_config = {"name": "psycopg2", "args": {}} + else: + db_config = {"name": "sqlite3"} + engine = create_engine(db_config) + fake_engine = Mock(wraps=engine) fake_engine.in_transaction.return_value = False + fake_engine.module.OperationalError = engine.module.OperationalError + fake_engine.module.DatabaseError = engine.module.DatabaseError + fake_engine.module.IntegrityError = engine.module.IntegrityError + # Don't convert param style to make assertions easier. + fake_engine.convert_param_style = lambda sql: sql + # To fix isinstance(...) checks. + fake_engine.__class__ = engine.__class__ # type: ignore[assignment] - db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) - db._db_pool = self.db_pool + db = DatabasePool(Mock(), Mock(config=db_config), fake_engine) + db._db_pool = conn_pool self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type] + def tearDown(self) -> None: + if USE_POSTGRES_FOR_TESTS: + self.execute_batch_patcher.stop() + self.execute_values_patcher.stop() + @defer.inlineCallbacks def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 @@ -71,7 +118,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "INSERT INTO tablename (columname) VALUES(?)", ("Value",) ) @@ -87,10 +134,73 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", (1, 2, 3) ) + @defer.inlineCallbacks + def test_insert_many(self) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert_many( + table="tablename", + keys=( + "col1", + "col2", + ), + values=[ + ( + "val1", + "val2", + ), + ("val3", "val4"), + ], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_values.assert_called_once_with( + self.mock_txn, + "INSERT INTO tablename (col1, col2) VALUES ?", + [("val1", "val2"), ("val3", "val4")], + template=None, + fetch=False, + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "INSERT INTO tablename (col1, col2) VALUES(?, ?)", + [("val1", "val2"), ("val3", "val4")], + ) + + @defer.inlineCallbacks + def test_insert_many_no_iterable( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert_many( + table="tablename", + keys=( + "col1", + "col2", + ), + values=[], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_values.assert_called_once_with( + self.mock_txn, + "INSERT INTO tablename (col1, col2) VALUES ?", + [], + template=None, + fetch=False, + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "INSERT INTO tablename (col1, col2) VALUES(?, ?)", [] + ) + @defer.inlineCallbacks def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 @@ -103,7 +213,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) self.assertEqual("Value", value) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -121,7 +231,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -156,10 +266,58 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) self.assertEqual([(1,), (2,), (3,)], ret) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) + @defer.inlineCallbacks + def test_select_many_batch( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 3 + self.mock_txn.fetchall.side_effect = [[(1,), (2,)], [(3,)]] + + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_many_batch( + table="tablename", + column="col1", + iterable=("val1", "val2", "val3"), + retcols=("col2",), + keyvalues={"col3": "val4"}, + batch_size=2, + ) + ) + + self.mock_txn.execute.assert_has_calls( + [ + call( + "SELECT col2 FROM tablename WHERE col1 = ANY(?) AND col3 = ?", + [["val1", "val2"], "val4"], + ), + call( + "SELECT col2 FROM tablename WHERE col1 = ANY(?) AND col3 = ?", + [["val3"], "val4"], + ), + ], + ) + self.assertEqual([(1,), (2,), (3,)], ret) + + def test_select_many_no_iterable(self) -> None: + self.mock_txn.rowcount = 3 + self.mock_txn.fetchall.side_effect = [(1,), (2,)] + + ret = self.datastore.db_pool.simple_select_many_txn( + self.mock_txn, + table="tablename", + column="col1", + iterable=(), + retcols=("col2",), + keyvalues={"col3": "val4"}, + ) + + self.mock_txn.execute.assert_not_called() + self.assertEqual([], ret) + @defer.inlineCallbacks def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 @@ -172,7 +330,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "UPDATE tablename SET columnname = ? WHERE keycol = ?", ["New Value", "TheKey"], ) @@ -191,11 +349,76 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?", [3, 4, 1, 2], ) + @defer.inlineCallbacks + def test_update_many(self) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_many( + table="tablename", + key_names=("col1", "col2"), + key_values=[("val1", "val2")], + value_names=("col3",), + value_values=[("val3",)], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_batch.assert_called_once_with( + self.mock_txn, + "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?", + [("val3", "val1", "val2"), ("val3", "val1", "val2")], + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?", + [("val3", "val1", "val2"), ("val3", "val1", "val2")], + ) + + # key_values and value_values must be the same length. + with self.assertRaises(ValueError): + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_many( + table="tablename", + key_names=("col1", "col2"), + key_values=[("val1", "val2")], + value_names=("col3",), + value_values=[], + desc="", + ) + ) + + @defer.inlineCallbacks + def test_update_many_no_values( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_many( + table="tablename", + key_names=("col1", "col2"), + key_values=[], + value_names=("col3",), + value_values=[], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_batch.assert_called_once_with( + self.mock_txn, + "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?", + [], + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?", + [], + ) + @defer.inlineCallbacks def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 @@ -206,6 +429,393 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.mock_txn.execute.assert_called_with( + self.mock_txn.execute.assert_called_once_with( "DELETE FROM tablename WHERE keycol = ?", ["Go away"] ) + + @defer.inlineCallbacks + def test_delete_many(self) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 2 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_many( + table="tablename", + column="col1", + iterable=("val1", "val2"), + keyvalues={"col2": "val3"}, + desc="", + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "DELETE FROM tablename WHERE col1 = ANY(?) AND col2 = ?", + [["val1", "val2"], "val3"], + ) + self.assertEqual(result, 2) + + @defer.inlineCallbacks + def test_delete_many_no_iterable( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_many( + table="tablename", + column="col1", + iterable=(), + keyvalues={"col2": "val3"}, + desc="", + ) + ) + + self.mock_txn.execute.assert_not_called() + self.assertEqual(result, 0) + + @defer.inlineCallbacks + def test_delete_many_no_keyvalues( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 2 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_many( + table="tablename", + column="col1", + iterable=("val1", "val2"), + keyvalues={}, + desc="", + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "DELETE FROM tablename WHERE col1 = ANY(?)", [["val1", "val2"]] + ) + self.assertEqual(result, 2) + + @defer.inlineCallbacks + def test_upsert(self) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "INSERT INTO tablename (columnname, othercol) VALUES (?, ?) ON CONFLICT (columnname) DO UPDATE SET othercol=EXCLUDED.othercol", + ["oldvalue", "newvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_no_values( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "value"}, + values={}, + insertion_values={"columnname": "value"}, + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "INSERT INTO tablename (columnname) VALUES (?) ON CONFLICT (columnname) DO NOTHING", + ["value"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_with_insertion( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + insertion_values={"thirdcol": "insertionval"}, + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "INSERT INTO tablename (columnname, thirdcol, othercol) VALUES (?, ?, ?) ON CONFLICT (columnname) DO UPDATE SET othercol=EXCLUDED.othercol", + ["oldvalue", "insertionval", "newvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_with_where( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + where_clause="thirdcol IS NULL", + ) + ) + + self.mock_txn.execute.assert_called_once_with( + "INSERT INTO tablename (columnname, othercol) VALUES (?, ?) ON CONFLICT (columnname) WHERE thirdcol IS NULL DO UPDATE SET othercol=EXCLUDED.othercol", + ["oldvalue", "newvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_many(self) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert_many( + table="tablename", + key_names=["keycol1", "keycol2"], + key_values=[["keyval1", "keyval2"], ["keyval3", "keyval4"]], + value_names=["valuecol3"], + value_values=[["val5"], ["val6"]], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_values.assert_called_once_with( + self.mock_txn, + "INSERT INTO tablename (keycol1, keycol2, valuecol3) VALUES ? ON CONFLICT (keycol1, keycol2) DO UPDATE SET valuecol3=EXCLUDED.valuecol3", + [("keyval1", "keyval2", "val5"), ("keyval3", "keyval4", "val6")], + template=None, + fetch=False, + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "INSERT INTO tablename (keycol1, keycol2, valuecol3) VALUES (?, ?, ?) ON CONFLICT (keycol1, keycol2) DO UPDATE SET valuecol3=EXCLUDED.valuecol3", + [("keyval1", "keyval2", "val5"), ("keyval3", "keyval4", "val6")], + ) + + @defer.inlineCallbacks + def test_upsert_many_no_values( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert_many( + table="tablename", + key_names=["columnname"], + key_values=[["oldvalue"]], + value_names=[], + value_values=[], + desc="", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_execute_values.assert_called_once_with( + self.mock_txn, + "INSERT INTO tablename (columnname) VALUES ? ON CONFLICT (columnname) DO NOTHING", + [("oldvalue",)], + template=None, + fetch=False, + ) + else: + self.mock_txn.executemany.assert_called_once_with( + "INSERT INTO tablename (columnname) VALUES (?) ON CONFLICT (columnname) DO NOTHING", + [("oldvalue",)], + ) + + @defer.inlineCallbacks + def test_upsert_emulated_no_values_exists( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.fetchall.return_value = [(1,)] + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "value"}, + values={}, + insertion_values={"columnname": "value"}, + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_txn.execute.assert_has_calls( + [ + call("LOCK TABLE tablename in EXCLUSIVE MODE", ()), + call("SELECT 1 FROM tablename WHERE columnname = ?", ["value"]), + ] + ) + else: + self.mock_txn.execute.assert_called_once_with( + "SELECT 1 FROM tablename WHERE columnname = ?", ["value"] + ) + self.assertFalse(result) + + @defer.inlineCallbacks + def test_upsert_emulated_no_values_not_exists( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.fetchall.return_value = [] + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "value"}, + values={}, + insertion_values={"columnname": "value"}, + ) + ) + + self.mock_txn.execute.assert_has_calls( + [ + call( + "SELECT 1 FROM tablename WHERE columnname = ?", + ["value"], + ), + call("INSERT INTO tablename (columnname) VALUES (?)", ["value"]), + ], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_emulated_with_insertion_exists( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + insertion_values={"thirdcol": "insertionval"}, + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_txn.execute.assert_has_calls( + [ + call("LOCK TABLE tablename in EXCLUSIVE MODE", ()), + call( + "UPDATE tablename SET othercol = ? WHERE columnname = ?", + ["newvalue", "oldvalue"], + ), + ] + ) + else: + self.mock_txn.execute.assert_called_once_with( + "UPDATE tablename SET othercol = ? WHERE columnname = ?", + ["newvalue", "oldvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_emulated_with_insertion_not_exists( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.rowcount = 0 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + insertion_values={"thirdcol": "insertionval"}, + ) + ) + + self.mock_txn.execute.assert_has_calls( + [ + call( + "UPDATE tablename SET othercol = ? WHERE columnname = ?", + ["newvalue", "oldvalue"], + ), + call( + "INSERT INTO tablename (columnname, othercol, thirdcol) VALUES (?, ?, ?)", + ["oldvalue", "newvalue", "insertionval"], + ), + ] + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_emulated_with_where( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={"othercol": "newvalue"}, + where_clause="thirdcol IS NULL", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_txn.execute.assert_has_calls( + [ + call("LOCK TABLE tablename in EXCLUSIVE MODE", ()), + call( + "UPDATE tablename SET othercol = ? WHERE columnname = ? AND thirdcol IS NULL", + ["newvalue", "oldvalue"], + ), + ] + ) + else: + self.mock_txn.execute.assert_called_once_with( + "UPDATE tablename SET othercol = ? WHERE columnname = ? AND thirdcol IS NULL", + ["newvalue", "oldvalue"], + ) + self.assertTrue(result) + + @defer.inlineCallbacks + def test_upsert_emulated_with_where_no_values( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename") + + self.mock_txn.rowcount = 1 + + result = yield defer.ensureDeferred( + self.datastore.db_pool.simple_upsert( + table="tablename", + keyvalues={"columnname": "oldvalue"}, + values={}, + where_clause="thirdcol IS NULL", + ) + ) + + if USE_POSTGRES_FOR_TESTS: + self.mock_txn.execute.assert_has_calls( + [ + call("LOCK TABLE tablename in EXCLUSIVE MODE", ()), + call( + "SELECT 1 FROM tablename WHERE columnname = ? AND thirdcol IS NULL", + ["oldvalue"], + ), + ] + ) + else: + self.mock_txn.execute.assert_called_once_with( + "SELECT 1 FROM tablename WHERE columnname = ? AND thirdcol IS NULL", + ["oldvalue"], + ) + self.assertFalse(result)