Added functionality to allow SqlDataset
to interpret a database column as various numeric types, including several integer types and dtypes.float64
.
PiperOrigin-RevId: 168055827
This commit is contained in:
parent
fa2000a0b0
commit
be1916ce7e
@ -49,25 +49,46 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute("DROP TABLE IF EXISTS students")
|
c.execute("DROP TABLE IF EXISTS students")
|
||||||
c.execute("DROP TABLE IF EXISTS people")
|
c.execute("DROP TABLE IF EXISTS people")
|
||||||
|
c.execute("DROP TABLE IF EXISTS townspeople")
|
||||||
c.execute(
|
c.execute(
|
||||||
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
|
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
|
||||||
"first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
|
"first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
|
||||||
"school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
|
"school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
|
||||||
"grade_level INTEGER, income INTEGER, favorite_number INTEGER)")
|
"desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
|
||||||
|
"favorite_big_number INTEGER, favorite_negative_number INTEGER, "
|
||||||
|
"favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
|
||||||
|
"account_balance INTEGER, registration_complete INTEGER)")
|
||||||
c.executemany(
|
c.executemany(
|
||||||
"INSERT INTO students (first_name, last_name, motto, school_id, "
|
"INSERT INTO students (first_name, last_name, motto, school_id, "
|
||||||
"favorite_nonsense_word, grade_level, income, favorite_number) "
|
"favorite_nonsense_word, desk_number, income, favorite_number, "
|
||||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
"favorite_big_number, favorite_negative_number, "
|
||||||
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647),
|
"favorite_medium_sized_number, brownie_points, account_balance, "
|
||||||
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 11, -20000,
|
"registration_complete) "
|
||||||
-2147483648)])
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
|
||||||
|
9223372036854775807, -2, 32767, 0, 0, 1),
|
||||||
|
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
|
||||||
|
-2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
|
||||||
c.execute(
|
c.execute(
|
||||||
"CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
|
"CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
|
||||||
"first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
|
"first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
|
||||||
c.executemany(
|
c.executemany(
|
||||||
"INSERT INTO people (first_name, last_name, state) VALUES (?, ?, ?)",
|
"INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
|
||||||
[("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
|
[("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
|
||||||
"California")])
|
"California")])
|
||||||
|
c.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
|
||||||
|
"KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
|
||||||
|
"FLOAT, accolades FLOAT, triumphs FLOAT)")
|
||||||
|
c.executemany(
|
||||||
|
"INSERT INTO townspeople (first_name, last_name, victories, "
|
||||||
|
"accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
[("George", "Washington", 20.00,
|
||||||
|
1331241.321342132321324589798264627463827647382647382643874,
|
||||||
|
9007199254740991.0),
|
||||||
|
("John", "Adams", -19.95,
|
||||||
|
1331241321342132321324589798264627463827647382647382643874.0,
|
||||||
|
9007199254740992.0)])
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@ -80,7 +101,6 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
sess.run(
|
sess.run(
|
||||||
init_op,
|
init_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
self.driver_name: "sqlite",
|
|
||||||
self.query: "SELECT first_name, last_name, motto FROM students "
|
self.query: "SELECT first_name, last_name, motto FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
@ -98,7 +118,6 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
sess.run(
|
sess.run(
|
||||||
init_op,
|
init_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
self.driver_name: "sqlite",
|
|
||||||
self.query:
|
self.query:
|
||||||
"SELECT students.first_name, state, motto FROM students "
|
"SELECT students.first_name, state, motto FROM students "
|
||||||
"INNER JOIN people "
|
"INNER JOIN people "
|
||||||
@ -118,7 +137,6 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
sess.run(
|
sess.run(
|
||||||
init_op,
|
init_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
self.driver_name: "sqlite",
|
|
||||||
self.query:
|
self.query:
|
||||||
"SELECT first_name, last_name, favorite_nonsense_word "
|
"SELECT first_name, last_name, favorite_nonsense_word "
|
||||||
"FROM students ORDER BY first_name DESC"
|
"FROM students ORDER BY first_name DESC"
|
||||||
@ -249,20 +267,124 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||||
|
# place it in an `int8` tensor.
|
||||||
|
def testReadResultSetInt8(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||||
|
# SQLite database table and place it in an `int8` tensor.
|
||||||
|
def testReadResultSetInt8NegativeAndZero(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
|
||||||
|
dtypes.int8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, income, favorite_negative_number "
|
||||||
|
"FROM students "
|
||||||
|
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||||
|
# a SQLite database table and place it in an `int8` tensor.
|
||||||
|
def testReadResultSetInt8MaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT desk_number, favorite_negative_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((9, -2), sess.run(get_next))
|
||||||
|
# Max and min values of int8
|
||||||
|
self.assertEqual((127, -128), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||||
|
# place it in an `int16` tensor.
|
||||||
|
def testReadResultSetInt16(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||||
|
# SQLite database table and place it in an `int16` tensor.
|
||||||
|
def testReadResultSetInt16NegativeAndZero(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
|
||||||
|
dtypes.int16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, income, favorite_negative_number "
|
||||||
|
"FROM students "
|
||||||
|
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||||
|
# a SQLite database table and place it in an `int16` tensor.
|
||||||
|
def testReadResultSetInt16MaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||||
|
"FROM students ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
# Max value of int16
|
||||||
|
self.assertEqual((b"John", 32767), sess.run(get_next))
|
||||||
|
# Min value of int16
|
||||||
|
self.assertEqual((b"Jane", -32768), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||||
|
# place it in an `int32` tensor.
|
||||||
def testReadResultSetInt32(self):
|
def testReadResultSetInt32(self):
|
||||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(
|
sess.run(
|
||||||
init_op,
|
init_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
self.query: "SELECT first_name, grade_level FROM students "
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
self.assertEqual((b"Jane", 11), sess.run(get_next))
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
|
||||||
sess.run(get_next)
|
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||||
|
# SQLite database table and place it in an `int32` tensor.
|
||||||
def testReadResultSetInt32NegativeAndZero(self):
|
def testReadResultSetInt32NegativeAndZero(self):
|
||||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -277,6 +399,8 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||||
|
# a SQLite database table and place it in an `int32` tensor.
|
||||||
def testReadResultSetInt32MaxValues(self):
|
def testReadResultSetInt32MaxValues(self):
|
||||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -286,7 +410,9 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
self.query: "SELECT first_name, favorite_number FROM students "
|
self.query: "SELECT first_name, favorite_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
|
# Max value of int32
|
||||||
self.assertEqual((b"John", 2147483647), sess.run(get_next))
|
self.assertEqual((b"John", 2147483647), sess.run(get_next))
|
||||||
|
# Min value of int32
|
||||||
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
|
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -307,6 +433,224 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||||
|
# and place it in an `int64` tensor.
|
||||||
|
def testReadResultSetInt64(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||||
|
# SQLite database table and place it in an `int64` tensor.
|
||||||
|
def testReadResultSetInt64NegativeAndZero(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, income FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||||
|
# a SQLite database table and place it in an `int64` tensor.
|
||||||
|
def testReadResultSetInt64MaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, favorite_big_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
# Max value of int64
|
||||||
|
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
|
||||||
|
# Min value of int64
|
||||||
|
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||||
|
# place it in a `uint8` tensor.
|
||||||
|
def testReadResultSetUInt8(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
|
||||||
|
# SQLite database table and place them in `uint8` tensors.
|
||||||
|
def testReadResultSetUInt8MinAndMaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, brownie_points FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
# Min value of uint8
|
||||||
|
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||||
|
# Max value of uint8
|
||||||
|
self.assertEqual((b"Jane", 255), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||||
|
# and place it in a `uint16` tensor.
|
||||||
|
def testReadResultSetUInt16(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
|
||||||
|
# SQLite database table and place them in `uint16` tensors.
|
||||||
|
def testReadResultSetUInt16MinAndMaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, account_balance FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
# Min value of uint16
|
||||||
|
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||||
|
# Max value of uint16
|
||||||
|
self.assertEqual((b"Jane", 65535), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
|
||||||
|
# SQLite database table and place them as `True` and `False` respectively
|
||||||
|
# in `bool` tensors.
|
||||||
|
def testReadResultSetBool(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, registration_complete FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", True), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", False), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
|
||||||
|
# from a SQLite database table and place it as `True` in a `bool` tensor.
|
||||||
|
def testReadResultSetBoolNotZeroOrOne(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||||
|
"FROM students ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", True), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", True), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a float from a SQLite database table
|
||||||
|
# and place it in a `float64` tensor.
|
||||||
|
def testReadResultSetFloat64(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
|
||||||
|
dtypes.float64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, last_name, victories FROM townspeople "
|
||||||
|
"ORDER BY first_name"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
|
||||||
|
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a float from a SQLite database table beyond
|
||||||
|
# the precision of 64-bit IEEE, without throwing an error. Test that
|
||||||
|
# `SqlDataset` identifies such a value as equal to itself.
|
||||||
|
def testReadResultSetFloat64OverlyPrecise(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
|
||||||
|
dtypes.float64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, last_name, accolades FROM townspeople "
|
||||||
|
"ORDER BY first_name"
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
(b"George", b"Washington",
|
||||||
|
1331241.321342132321324589798264627463827647382647382643874),
|
||||||
|
sess.run(get_next))
|
||||||
|
self.assertEqual(
|
||||||
|
(b"John", b"Adams",
|
||||||
|
1331241321342132321324589798264627463827647382647382643874.0),
|
||||||
|
sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a float from a SQLite database table,
|
||||||
|
# representing the largest integer representable as a 64-bit IEEE float
|
||||||
|
# such that the previous integer is also representable as a 64-bit IEEE float.
|
||||||
|
# Test that `SqlDataset` can distinguish these two numbers.
|
||||||
|
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
|
||||||
|
dtypes.float64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, last_name, triumphs FROM townspeople "
|
||||||
|
"ORDER BY first_name"
|
||||||
|
})
|
||||||
|
self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
|
||||||
|
sess.run(get_next))
|
||||||
|
self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
|
||||||
|
sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -82,8 +82,6 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
|
|||||||
int rc = sqlite3_step(stmt_);
|
int rc = sqlite3_step(stmt_);
|
||||||
if (rc == SQLITE_ROW) {
|
if (rc == SQLITE_ROW) {
|
||||||
for (int i = 0; i < column_count_; i++) {
|
for (int i = 0; i < column_count_; i++) {
|
||||||
// TODO(b/64276939) Support other tensorflow types. Interpret columns as
|
|
||||||
// the types that the client specifies.
|
|
||||||
DataType dt = output_types_[i];
|
DataType dt = output_types_[i];
|
||||||
Tensor tensor(cpu_allocator(), dt, {});
|
Tensor tensor(cpu_allocator(), dt, {});
|
||||||
FillTensorWithResultSetEntry(dt, i, &tensor);
|
FillTensorWithResultSetEntry(dt, i, &tensor);
|
||||||
@ -125,11 +123,46 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
|
|||||||
tensor->scalar<string>()() = value;
|
tensor->scalar<string>()() = value;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case DT_INT8: {
|
||||||
|
int8 value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<int8>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_INT16: {
|
||||||
|
int16 value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<int16>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
case DT_INT32: {
|
case DT_INT32: {
|
||||||
int32 value = sqlite3_column_int(stmt_, column_index);
|
int32 value = sqlite3_column_int(stmt_, column_index);
|
||||||
tensor->scalar<int32>()() = value;
|
tensor->scalar<int32>()() = value;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case DT_INT64: {
|
||||||
|
int64 value = sqlite3_column_int64(stmt_, column_index);
|
||||||
|
tensor->scalar<int64>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_UINT8: {
|
||||||
|
uint8 value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<uint8>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_UINT16: {
|
||||||
|
uint16 value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<uint16>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_BOOL: {
|
||||||
|
int value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<bool>()() = value ? true : false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_DOUBLE: {
|
||||||
|
double value = sqlite3_column_double(stmt_, column_index);
|
||||||
|
tensor->scalar<double>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
|
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
|
||||||
default: {
|
default: {
|
||||||
LOG(FATAL)
|
LOG(FATAL)
|
||||||
|
@ -34,13 +34,15 @@ class SqlDatasetOp : public DatasetOpKernel {
|
|||||||
explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
|
explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||||
// TODO(b/64276939) Remove this check when we add support for other
|
|
||||||
// tensorflow types.
|
|
||||||
for (const DataType& dt : output_types_) {
|
for (const DataType& dt : output_types_) {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(ctx,
|
||||||
ctx, dt == DT_STRING || dt == DT_INT32,
|
dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 ||
|
||||||
|
dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 ||
|
||||||
|
dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Each element of `output_types_` must be DT_STRING or DT_INT32"));
|
"Each element of `output_types_` must be one of: "
|
||||||
|
"DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, "
|
||||||
|
"DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE "));
|
||||||
}
|
}
|
||||||
for (const PartialTensorShape& pts : output_shapes_) {
|
for (const PartialTensorShape& pts : output_shapes_) {
|
||||||
OP_REQUIRES(ctx, pts.dims() == 0,
|
OP_REQUIRES(ctx, pts.dims() == 0,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user