Add uint32 and uint64 support with tf.train.batch (#18805)

* Add uint32 and uint64 support with tf.train.batch

This fix tries to address the issue raised in 18586
to have uint32 and uint64 support with tf.train.batch.

This fix add uint32 and uint64 to `CopyElementToSlice`
for the support.

This fix fixes 18586.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test case for uint32 with tf.train.batch

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add uint64 test case

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2018-04-28 12:04:45 -07:00 committed by drpngx
parent 9f9b511659
commit c6aa3a0624
2 changed files with 24 additions and 0 deletions

View File

@ -134,6 +134,8 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) {
switch (element.dtype()) {
TF_CALL_ALL_TYPES(HANDLE_TYPE);
TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
TF_CALL_uint32(HANDLE_TYPE);
TF_CALL_uint64(HANDLE_TYPE);
#undef HANDLE_TYPE
default:
return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",

View File

@ -497,6 +497,28 @@ class BatchTest(test_lib.TestCase):
def testOneThreadDict(self):
self._testOneThreadHelper(use_dict=True)
def testUint32DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32)
batched = inp.batch([values], batch_size=2)
with self.test_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
sess.run(batched)
coord.request_stop()
for thread in threads:
thread.join()
def testUint64DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64)
batched = inp.batch([values], batch_size=2)
with self.test_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
sess.run(batched)
coord.request_stop()
for thread in threads:
thread.join()
def testOneThreadDynamicPad(self):
with self.test_session() as sess:
batch_size = 10