Added functionality to allow SqlDataset to interpret a database column as various numeric types, including several integer types and dtypes.float64.

PiperOrigin-RevId: 168055827
This commit is contained in:
Daniel Grazian 2017-09-08 15:35:35 -07:00 committed by TensorFlower Gardener
parent fa2000a0b0
commit be1916ce7e
3 changed files with 404 additions and 25 deletions

View File

@ -49,25 +49,46 @@ class SqlDatasetTest(test.TestCase):
c = conn.cursor()
c.execute("DROP TABLE IF EXISTS students")
c.execute("DROP TABLE IF EXISTS people")
c.execute("DROP TABLE IF EXISTS townspeople")
c.execute(
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY,"
" first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100),"
" school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
"grade_level INTEGER, income INTEGER, favorite_number INTEGER)")
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
"first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
"school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
"desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
"favorite_big_number INTEGER, favorite_negative_number INTEGER, "
"favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
"account_balance INTEGER, registration_complete INTEGER)")
c.executemany(
"INSERT INTO students (first_name, last_name, motto, school_id, "
"favorite_nonsense_word, grade_level, income, favorite_number) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647),
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 11, -20000,
-2147483648)])
"favorite_nonsense_word, desk_number, income, favorite_number, "
"favorite_big_number, favorite_negative_number, "
"favorite_medium_sized_number, brownie_points, account_balance, "
"registration_complete) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
9223372036854775807, -2, 32767, 0, 0, 1),
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
-2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
c.execute(
"CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
"first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
c.executemany(
"INSERT INTO people (first_name, last_name, state) VALUES (?, ?, ?)",
"INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
[("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
"California")])
c.execute(
"CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
"KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
"FLOAT, accolades FLOAT, triumphs FLOAT)")
c.executemany(
"INSERT INTO townspeople (first_name, last_name, victories, "
"accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
[("George", "Washington", 20.00,
1331241.321342132321324589798264627463827647382647382643874,
9007199254740991.0),
("John", "Adams", -19.95,
1331241321342132321324589798264627463827647382647382643874.0,
9007199254740992.0)])
conn.commit()
conn.close()
@ -80,7 +101,6 @@ class SqlDatasetTest(test.TestCase):
sess.run(
init_op,
feed_dict={
self.driver_name: "sqlite",
self.query: "SELECT first_name, last_name, motto FROM students "
"ORDER BY first_name DESC"
})
@ -98,7 +118,6 @@ class SqlDatasetTest(test.TestCase):
sess.run(
init_op,
feed_dict={
self.driver_name: "sqlite",
self.query:
"SELECT students.first_name, state, motto FROM students "
"INNER JOIN people "
@ -118,7 +137,6 @@ class SqlDatasetTest(test.TestCase):
sess.run(
init_op,
feed_dict={
self.driver_name: "sqlite",
self.query:
"SELECT first_name, last_name, favorite_nonsense_word "
"FROM students ORDER BY first_name DESC"
@ -249,20 +267,124 @@ class SqlDatasetTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int8` tensor.
def testReadResultSetInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int8` tensor.
def testReadResultSetInt8NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
dtypes.int8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, income, favorite_negative_number "
"FROM students "
"WHERE first_name = 'John' ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0, -2), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int8` tensor.
def testReadResultSetInt8MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT desk_number, favorite_negative_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((9, -2), sess.run(get_next))
# Max and min values of int8
self.assertEqual((127, -128), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int16` tensor.
def testReadResultSetInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int16` tensor.
def testReadResultSetInt16NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
dtypes.int16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, income, favorite_negative_number "
"FROM students "
"WHERE first_name = 'John' ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0, -2), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int16` tensor.
def testReadResultSetInt16MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, favorite_medium_sized_number "
"FROM students ORDER BY first_name DESC"
})
# Max value of int16
self.assertEqual((b"John", 32767), sess.run(get_next))
# Min value of int16
self.assertEqual((b"Jane", -32768), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int32` tensor.
def testReadResultSetInt32(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, grade_level FROM students "
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 11), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertEqual((b"Jane", 127), sess.run(get_next))
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
with self.test_session() as sess:
@ -277,6 +399,8 @@ class SqlDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
with self.test_session() as sess:
@ -286,7 +410,9 @@ class SqlDatasetTest(test.TestCase):
self.query: "SELECT first_name, favorite_number FROM students "
"ORDER BY first_name DESC"
})
# Max value of int32
self.assertEqual((b"John", 2147483647), sess.run(get_next))
# Min value of int32
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -307,6 +433,224 @@ class SqlDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in an `int64` tensor.
def testReadResultSetInt64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, income FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0), sess.run(get_next))
self.assertEqual((b"Jane", -20000), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, favorite_big_number FROM students "
"ORDER BY first_name DESC"
})
# Max value of int64
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
# Min value of int64
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in a `uint8` tensor.
def testReadResultSetUInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
# SQLite database table and place them in `uint8` tensors.
def testReadResultSetUInt8MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, brownie_points FROM students "
"ORDER BY first_name DESC"
})
# Min value of uint8
self.assertEqual((b"John", 0), sess.run(get_next))
# Max value of uint8
self.assertEqual((b"Jane", 255), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in a `uint16` tensor.
def testReadResultSetUInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
# SQLite database table and place them in `uint16` tensors.
def testReadResultSetUInt16MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, account_balance FROM students "
"ORDER BY first_name DESC"
})
# Min value of uint16
self.assertEqual((b"John", 0), sess.run(get_next))
# Max value of uint16
self.assertEqual((b"Jane", 65535), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
# SQLite database table and place them as `True` and `False` respectively
# in `bool` tensors.
def testReadResultSetBool(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, registration_complete FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", True), sess.run(get_next))
self.assertEqual((b"Jane", False), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
# from a SQLite database table and place it as `True` in a `bool` tensor.
def testReadResultSetBoolNotZeroOrOne(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, favorite_medium_sized_number "
"FROM students ORDER BY first_name DESC"
})
self.assertEqual((b"John", True), sess.run(get_next))
self.assertEqual((b"Jane", True), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table
# and place it in a `float64` tensor.
def testReadResultSetFloat64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, last_name, victories FROM townspeople "
"ORDER BY first_name"
})
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table beyond
# the precision of 64-bit IEEE, without throwing an error. Test that
# `SqlDataset` identifies such a value as equal to itself.
def testReadResultSetFloat64OverlyPrecise(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, last_name, accolades FROM townspeople "
"ORDER BY first_name"
})
self.assertEqual(
(b"George", b"Washington",
1331241.321342132321324589798264627463827647382647382643874),
sess.run(get_next))
self.assertEqual(
(b"John", b"Adams",
1331241321342132321324589798264627463827647382647382643874.0),
sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table,
# representing the largest integer representable as a 64-bit IEEE float
# such that the previous integer is also representable as a 64-bit IEEE float.
# Test that `SqlDataset` can distinguish these two numbers.
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, last_name, triumphs FROM townspeople "
"ORDER BY first_name"
})
self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
sess.run(get_next))
self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
if __name__ == "__main__":
test.main()

View File

@ -82,8 +82,6 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
int rc = sqlite3_step(stmt_);
if (rc == SQLITE_ROW) {
for (int i = 0; i < column_count_; i++) {
// TODO(b/64276939) Support other tensorflow types. Interpret columns as
// the types that the client specifies.
DataType dt = output_types_[i];
Tensor tensor(cpu_allocator(), dt, {});
FillTensorWithResultSetEntry(dt, i, &tensor);
@ -125,11 +123,46 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
tensor->scalar<string>()() = value;
break;
}
case DT_INT8: {
int8 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<int8>()() = value;
break;
}
case DT_INT16: {
int16 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<int16>()() = value;
break;
}
case DT_INT32: {
int32 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<int32>()() = value;
break;
}
case DT_INT64: {
int64 value = sqlite3_column_int64(stmt_, column_index);
tensor->scalar<int64>()() = value;
break;
}
case DT_UINT8: {
uint8 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<uint8>()() = value;
break;
}
case DT_UINT16: {
uint16 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<uint16>()() = value;
break;
}
case DT_BOOL: {
int value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<bool>()() = value ? true : false;
break;
}
case DT_DOUBLE: {
double value = sqlite3_column_double(stmt_, column_index);
tensor->scalar<double>()() = value;
break;
}
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
default: {
LOG(FATAL)

View File

@ -34,13 +34,15 @@ class SqlDatasetOp : public DatasetOpKernel {
explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
// TODO(b/64276939) Remove this check when we add support for other
// tensorflow types.
for (const DataType& dt : output_types_) {
OP_REQUIRES(
ctx, dt == DT_STRING || dt == DT_INT32,
errors::InvalidArgument(
"Each element of `output_types_` must be DT_STRING or DT_INT32"));
OP_REQUIRES(ctx,
dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 ||
dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 ||
dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE,
errors::InvalidArgument(
"Each element of `output_types_` must be one of: "
"DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, "
"DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE "));
}
for (const PartialTensorShape& pts : output_shapes_) {
OP_REQUIRES(ctx, pts.dims() == 0,