From 6bf335808079d7be6268aa850f459797984491de Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 10 Jan 2021 05:01:12 +0000 Subject: [PATCH] Add tf.qint32 support for tf.zeros This PR is part of 26069 where tf.zeros does not support basic type of tf.qint32 while all other qtypes have been supported (tf.{qint8|qint16|quint8|quint16} supported). Signed-off-by: Yong Tang --- tensorflow/core/kernels/constant_op.cc | 1 + tensorflow/core/kernels/fill_functor.cc | 1 + tensorflow/python/kernel_tests/constant_op_test.py | 11 +++++++++++ 3 files changed, 13 insertions(+) diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index f9b382ca6f0..0e1de2e6921 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -187,6 +187,7 @@ REGISTER_KERNEL(CPU, quint8); REGISTER_KERNEL(CPU, quint16); REGISTER_KERNEL(CPU, qint8); REGISTER_KERNEL(CPU, qint16); +REGISTER_KERNEL(CPU, qint32); #undef REGISTER_CPU_KERNEL diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc index 140497b06d0..4fddd1e1413 100644 --- a/tensorflow/core/kernels/fill_functor.cc +++ b/tensorflow/core/kernels/fill_functor.cc @@ -109,6 +109,7 @@ DEFINE_FILL_CPU(quint8); DEFINE_FILL_CPU(quint16); DEFINE_FILL_CPU(qint8); DEFINE_FILL_CPU(qint16); +DEFINE_FILL_CPU(qint32); #undef DEFINE_FILL_CPU diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index e965c52ee29..cb014fcedd4 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -478,6 +478,17 @@ class ZerosTest(test.TestCase): z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32)) self.assertFalse(np.any(z_value)) + @test_util.disable_tfrt("b/169901260") + def testQint32Dtype(self): + dtype = dtypes_lib.qint32 + z = array_ops.zeros([2, 3], dtype=dtype) + self.assertEqual(z.dtype, dtype) + self.assertEqual([2, 3], z.get_shape()) + # cast to int32 so that it can be compred with numpy + # where [qint|quint][8|16] are not available. + z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32)) + self.assertFalse(np.any(z_value)) + class ZerosLikeTest(test.TestCase):