From c6aa3a0624ef7e1ff95cc07dde20c74105c4a584 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 28 Apr 2018 12:04:45 -0700 Subject: [PATCH] Add uint32 and uint64 support with tf.train.batch (#18805) * Add uint32 and uint64 support with tf.train.batch This fix tries to address the issue raised in 18586 to have uint32 and uint64 support with tf.train.batch. This fix add uint32 and uint64 to `CopyElementToSlice` for the support. This fix fixes 18586. Signed-off-by: Yong Tang * Add test case for uint32 with tf.train.batch Signed-off-by: Yong Tang * Add uint64 test case Signed-off-by: Yong Tang --- tensorflow/core/kernels/batch_util.cc | 2 ++ tensorflow/python/training/input_test.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/tensorflow/core/kernels/batch_util.cc b/tensorflow/core/kernels/batch_util.cc index 52be1ab8d0f..1182ed42e7a 100644 --- a/tensorflow/core/kernels/batch_util.cc +++ b/tensorflow/core/kernels/batch_util.cc @@ -134,6 +134,8 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) { switch (element.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); + TF_CALL_uint32(HANDLE_TYPE); + TF_CALL_uint64(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopyElementToSlice Unhandled data type: ", diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index 3a25bfe3432..1b1e89cb26d 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -497,6 +497,28 @@ class BatchTest(test_lib.TestCase): def testOneThreadDict(self): self._testOneThreadHelper(use_dict=True) + def testUint32DataTypes(self): + values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32) + batched = inp.batch([values], batch_size=2) + with self.test_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) + sess.run(batched) + coord.request_stop() + for thread in threads: + thread.join() + + def testUint64DataTypes(self): + values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64) + batched = inp.batch([values], batch_size=2) + with self.test_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) + sess.run(batched) + coord.request_stop() + for thread in threads: + thread.join() + def testOneThreadDynamicPad(self): with self.test_session() as sess: batch_size = 10