Add uint32 and uint64 support with tf.train.batch (#18805)
* Add uint32 and uint64 support with tf.train.batch This fix tries to address the issue raised in 18586 to have uint32 and uint64 support with tf.train.batch. This fix add uint32 and uint64 to `CopyElementToSlice` for the support. This fix fixes 18586. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case for uint32 with tf.train.batch Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add uint64 test case Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
9f9b511659
commit
c6aa3a0624
@ -134,6 +134,8 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) {
|
||||
switch (element.dtype()) {
|
||||
TF_CALL_ALL_TYPES(HANDLE_TYPE);
|
||||
TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
|
||||
TF_CALL_uint32(HANDLE_TYPE);
|
||||
TF_CALL_uint64(HANDLE_TYPE);
|
||||
#undef HANDLE_TYPE
|
||||
default:
|
||||
return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
|
||||
|
@ -497,6 +497,28 @@ class BatchTest(test_lib.TestCase):
|
||||
def testOneThreadDict(self):
|
||||
self._testOneThreadHelper(use_dict=True)
|
||||
|
||||
def testUint32DataTypes(self):
|
||||
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32)
|
||||
batched = inp.batch([values], batch_size=2)
|
||||
with self.test_session() as sess:
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
|
||||
sess.run(batched)
|
||||
coord.request_stop()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
def testUint64DataTypes(self):
|
||||
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64)
|
||||
batched = inp.batch([values], batch_size=2)
|
||||
with self.test_session() as sess:
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
|
||||
sess.run(batched)
|
||||
coord.request_stop()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
def testOneThreadDynamicPad(self):
|
||||
with self.test_session() as sess:
|
||||
batch_size = 10
|
||||
|
Loading…
Reference in New Issue
Block a user