Add tf.qint32 support for tf.zeros

This PR is part of 26069 where tf.zeros does not support
basic type of tf.qint32 while all other qtypes have been supported
(tf.{qint8|qint16|quint8|quint16} supported).

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2021-01-10 05:01:12 +00:00
parent 8730806d23
commit 6bf3358080
3 changed files with 13 additions and 0 deletions

View File

@ -187,6 +187,7 @@ REGISTER_KERNEL(CPU, quint8);
REGISTER_KERNEL(CPU, quint16);
REGISTER_KERNEL(CPU, qint8);
REGISTER_KERNEL(CPU, qint16);
REGISTER_KERNEL(CPU, qint32);
#undef REGISTER_CPU_KERNEL

View File

@ -109,6 +109,7 @@ DEFINE_FILL_CPU(quint8);
DEFINE_FILL_CPU(quint16);
DEFINE_FILL_CPU(qint8);
DEFINE_FILL_CPU(qint16);
DEFINE_FILL_CPU(qint32);
#undef DEFINE_FILL_CPU

View File

@ -478,6 +478,17 @@ class ZerosTest(test.TestCase):
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
self.assertFalse(np.any(z_value))
@test_util.disable_tfrt("b/169901260")
def testQint32Dtype(self):
dtype = dtypes_lib.qint32
z = array_ops.zeros([2, 3], dtype=dtype)
self.assertEqual(z.dtype, dtype)
self.assertEqual([2, 3], z.get_shape())
# cast to int32 so that it can be compred with numpy
# where [qint|quint][8|16] are not available.
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
self.assertFalse(np.any(z_value))
class ZerosLikeTest(test.TestCase):