Added functionality to allow SqlDataset
to interpret a database column as various numeric types, including several integer types and dtypes.float64
.
PiperOrigin-RevId: 168055827
This commit is contained in:
parent
fa2000a0b0
commit
be1916ce7e
@ -49,25 +49,46 @@ class SqlDatasetTest(test.TestCase):
|
||||
c = conn.cursor()
|
||||
c.execute("DROP TABLE IF EXISTS students")
|
||||
c.execute("DROP TABLE IF EXISTS people")
|
||||
c.execute("DROP TABLE IF EXISTS townspeople")
|
||||
c.execute(
|
||||
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY,"
|
||||
" first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100),"
|
||||
" school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
|
||||
"grade_level INTEGER, income INTEGER, favorite_number INTEGER)")
|
||||
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
|
||||
"first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
|
||||
"school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
|
||||
"desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
|
||||
"favorite_big_number INTEGER, favorite_negative_number INTEGER, "
|
||||
"favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
|
||||
"account_balance INTEGER, registration_complete INTEGER)")
|
||||
c.executemany(
|
||||
"INSERT INTO students (first_name, last_name, motto, school_id, "
|
||||
"favorite_nonsense_word, grade_level, income, favorite_number) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647),
|
||||
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 11, -20000,
|
||||
-2147483648)])
|
||||
"favorite_nonsense_word, desk_number, income, favorite_number, "
|
||||
"favorite_big_number, favorite_negative_number, "
|
||||
"favorite_medium_sized_number, brownie_points, account_balance, "
|
||||
"registration_complete) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
|
||||
9223372036854775807, -2, 32767, 0, 0, 1),
|
||||
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
|
||||
-2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
|
||||
c.execute(
|
||||
"CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
|
||||
"first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
|
||||
c.executemany(
|
||||
"INSERT INTO people (first_name, last_name, state) VALUES (?, ?, ?)",
|
||||
"INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
|
||||
[("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
|
||||
"California")])
|
||||
c.execute(
|
||||
"CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
|
||||
"KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
|
||||
"FLOAT, accolades FLOAT, triumphs FLOAT)")
|
||||
c.executemany(
|
||||
"INSERT INTO townspeople (first_name, last_name, victories, "
|
||||
"accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
|
||||
[("George", "Washington", 20.00,
|
||||
1331241.321342132321324589798264627463827647382647382643874,
|
||||
9007199254740991.0),
|
||||
("John", "Adams", -19.95,
|
||||
1331241321342132321324589798264627463827647382647382643874.0,
|
||||
9007199254740992.0)])
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
@ -80,7 +101,6 @@ class SqlDatasetTest(test.TestCase):
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.driver_name: "sqlite",
|
||||
self.query: "SELECT first_name, last_name, motto FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
@ -98,7 +118,6 @@ class SqlDatasetTest(test.TestCase):
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.driver_name: "sqlite",
|
||||
self.query:
|
||||
"SELECT students.first_name, state, motto FROM students "
|
||||
"INNER JOIN people "
|
||||
@ -118,7 +137,6 @@ class SqlDatasetTest(test.TestCase):
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.driver_name: "sqlite",
|
||||
self.query:
|
||||
"SELECT first_name, last_name, favorite_nonsense_word "
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
@ -249,20 +267,124 @@ class SqlDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in an `int8` tensor.
|
||||
def testReadResultSetInt8(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int8` tensor.
|
||||
def testReadResultSetInt8NegativeAndZero(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
|
||||
dtypes.int8))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, income, favorite_negative_number "
|
||||
"FROM students "
|
||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int8` tensor.
|
||||
def testReadResultSetInt8MaxValues(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query:
|
||||
"SELECT desk_number, favorite_negative_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((9, -2), sess.run(get_next))
|
||||
# Max and min values of int8
|
||||
self.assertEqual((127, -128), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in an `int16` tensor.
|
||||
def testReadResultSetInt16(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int16` tensor.
|
||||
def testReadResultSetInt16NegativeAndZero(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
|
||||
dtypes.int16))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, income, favorite_negative_number "
|
||||
"FROM students "
|
||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int16` tensor.
|
||||
def testReadResultSetInt16MaxValues(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int16
|
||||
self.assertEqual((b"John", 32767), sess.run(get_next))
|
||||
# Min value of int16
|
||||
self.assertEqual((b"Jane", -32768), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in an `int32` tensor.
|
||||
def testReadResultSetInt32(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, grade_level FROM students "
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 11), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int32` tensor.
|
||||
def testReadResultSetInt32NegativeAndZero(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
||||
with self.test_session() as sess:
|
||||
@ -277,6 +399,8 @@ class SqlDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int32` tensor.
|
||||
def testReadResultSetInt32MaxValues(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
||||
with self.test_session() as sess:
|
||||
@ -286,7 +410,9 @@ class SqlDatasetTest(test.TestCase):
|
||||
self.query: "SELECT first_name, favorite_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int32
|
||||
self.assertEqual((b"John", 2147483647), sess.run(get_next))
|
||||
# Min value of int32
|
||||
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -307,6 +433,224 @@ class SqlDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||
# and place it in an `int64` tensor.
|
||||
def testReadResultSetInt64(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int64` tensor.
|
||||
def testReadResultSetInt64NegativeAndZero(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, income FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int64` tensor.
|
||||
def testReadResultSetInt64MaxValues(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query:
|
||||
"SELECT first_name, favorite_big_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int64
|
||||
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
|
||||
# Min value of int64
|
||||
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in a `uint8` tensor.
|
||||
def testReadResultSetUInt8(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
|
||||
# SQLite database table and place them in `uint8` tensors.
|
||||
def testReadResultSetUInt8MinAndMaxValues(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, brownie_points FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Min value of uint8
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
# Max value of uint8
|
||||
self.assertEqual((b"Jane", 255), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||
# and place it in a `uint16` tensor.
|
||||
def testReadResultSetUInt16(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
|
||||
# SQLite database table and place them in `uint16` tensors.
|
||||
def testReadResultSetUInt16MinAndMaxValues(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, account_balance FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Min value of uint16
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
# Max value of uint16
|
||||
self.assertEqual((b"Jane", 65535), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
|
||||
# SQLite database table and place them as `True` and `False` respectively
|
||||
# in `bool` tensors.
|
||||
def testReadResultSetBool(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query:
|
||||
"SELECT first_name, registration_complete FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", True), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", False), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
|
||||
# from a SQLite database table and place it as `True` in a `bool` tensor.
|
||||
def testReadResultSetBoolNotZeroOrOne(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", True), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", True), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a float from a SQLite database table
|
||||
# and place it in a `float64` tensor.
|
||||
def testReadResultSetFloat64(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
|
||||
dtypes.float64))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query:
|
||||
"SELECT first_name, last_name, victories FROM townspeople "
|
||||
"ORDER BY first_name"
|
||||
})
|
||||
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a float from a SQLite database table beyond
|
||||
# the precision of 64-bit IEEE, without throwing an error. Test that
|
||||
# `SqlDataset` identifies such a value as equal to itself.
|
||||
def testReadResultSetFloat64OverlyPrecise(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
|
||||
dtypes.float64))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query:
|
||||
"SELECT first_name, last_name, accolades FROM townspeople "
|
||||
"ORDER BY first_name"
|
||||
})
|
||||
self.assertEqual(
|
||||
(b"George", b"Washington",
|
||||
1331241.321342132321324589798264627463827647382647382643874),
|
||||
sess.run(get_next))
|
||||
self.assertEqual(
|
||||
(b"John", b"Adams",
|
||||
1331241321342132321324589798264627463827647382647382643874.0),
|
||||
sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a float from a SQLite database table,
|
||||
# representing the largest integer representable as a 64-bit IEEE float
|
||||
# such that the previous integer is also representable as a 64-bit IEEE float.
|
||||
# Test that `SqlDataset` can distinguish these two numbers.
|
||||
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
|
||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
|
||||
dtypes.float64))
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query:
|
||||
"SELECT first_name, last_name, triumphs FROM townspeople "
|
||||
"ORDER BY first_name"
|
||||
})
|
||||
self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
|
||||
sess.run(get_next))
|
||||
self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
|
||||
sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -82,8 +82,6 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
|
||||
int rc = sqlite3_step(stmt_);
|
||||
if (rc == SQLITE_ROW) {
|
||||
for (int i = 0; i < column_count_; i++) {
|
||||
// TODO(b/64276939) Support other tensorflow types. Interpret columns as
|
||||
// the types that the client specifies.
|
||||
DataType dt = output_types_[i];
|
||||
Tensor tensor(cpu_allocator(), dt, {});
|
||||
FillTensorWithResultSetEntry(dt, i, &tensor);
|
||||
@ -125,11 +123,46 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
|
||||
tensor->scalar<string>()() = value;
|
||||
break;
|
||||
}
|
||||
case DT_INT8: {
|
||||
int8 value = sqlite3_column_int(stmt_, column_index);
|
||||
tensor->scalar<int8>()() = value;
|
||||
break;
|
||||
}
|
||||
case DT_INT16: {
|
||||
int16 value = sqlite3_column_int(stmt_, column_index);
|
||||
tensor->scalar<int16>()() = value;
|
||||
break;
|
||||
}
|
||||
case DT_INT32: {
|
||||
int32 value = sqlite3_column_int(stmt_, column_index);
|
||||
tensor->scalar<int32>()() = value;
|
||||
break;
|
||||
}
|
||||
case DT_INT64: {
|
||||
int64 value = sqlite3_column_int64(stmt_, column_index);
|
||||
tensor->scalar<int64>()() = value;
|
||||
break;
|
||||
}
|
||||
case DT_UINT8: {
|
||||
uint8 value = sqlite3_column_int(stmt_, column_index);
|
||||
tensor->scalar<uint8>()() = value;
|
||||
break;
|
||||
}
|
||||
case DT_UINT16: {
|
||||
uint16 value = sqlite3_column_int(stmt_, column_index);
|
||||
tensor->scalar<uint16>()() = value;
|
||||
break;
|
||||
}
|
||||
case DT_BOOL: {
|
||||
int value = sqlite3_column_int(stmt_, column_index);
|
||||
tensor->scalar<bool>()() = value ? true : false;
|
||||
break;
|
||||
}
|
||||
case DT_DOUBLE: {
|
||||
double value = sqlite3_column_double(stmt_, column_index);
|
||||
tensor->scalar<double>()() = value;
|
||||
break;
|
||||
}
|
||||
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
|
||||
default: {
|
||||
LOG(FATAL)
|
||||
|
@ -34,13 +34,15 @@ class SqlDatasetOp : public DatasetOpKernel {
|
||||
explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
// TODO(b/64276939) Remove this check when we add support for other
|
||||
// tensorflow types.
|
||||
for (const DataType& dt : output_types_) {
|
||||
OP_REQUIRES(
|
||||
ctx, dt == DT_STRING || dt == DT_INT32,
|
||||
errors::InvalidArgument(
|
||||
"Each element of `output_types_` must be DT_STRING or DT_INT32"));
|
||||
OP_REQUIRES(ctx,
|
||||
dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 ||
|
||||
dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 ||
|
||||
dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE,
|
||||
errors::InvalidArgument(
|
||||
"Each element of `output_types_` must be one of: "
|
||||
"DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, "
|
||||
"DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE "));
|
||||
}
|
||||
for (const PartialTensorShape& pts : output_shapes_) {
|
||||
OP_REQUIRES(ctx, pts.dims() == 0,
|
||||
|
Loading…
Reference in New Issue
Block a user