Merge pull request #46313 from yongtang:26069-tf.zeros-qint32

PiperOrigin-RevId: 351429021
Change-Id: I41a20f65f1d09447cb78e32c0d51dbbcd47b0165
This commit is contained in:
TensorFlower Gardener 2021-01-12 13:45:02 -08:00
commit 637a24786c
3 changed files with 13 additions and 0 deletions

View File

@ -187,6 +187,7 @@ REGISTER_KERNEL(CPU, quint8);
REGISTER_KERNEL(CPU, quint16);
REGISTER_KERNEL(CPU, qint8);
REGISTER_KERNEL(CPU, qint16);
REGISTER_KERNEL(CPU, qint32);
#undef REGISTER_CPU_KERNEL

View File

@ -109,6 +109,7 @@ DEFINE_FILL_CPU(quint8);
DEFINE_FILL_CPU(quint16);
DEFINE_FILL_CPU(qint8);
DEFINE_FILL_CPU(qint16);
DEFINE_FILL_CPU(qint32);
#undef DEFINE_FILL_CPU

View File

@ -478,6 +478,17 @@ class ZerosTest(test.TestCase):
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
self.assertFalse(np.any(z_value))
@test_util.disable_tfrt("b/169901260")
def testQint32Dtype(self):
dtype = dtypes_lib.qint32
z = array_ops.zeros([2, 3], dtype=dtype)
self.assertEqual(z.dtype, dtype)
self.assertEqual([2, 3], z.get_shape())
# cast to int32 so that it can be compred with numpy
# where [qint|quint][8|16] are not available.
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
self.assertFalse(np.any(z_value))
class ZerosLikeTest(test.TestCase):