Merge pull request #46313 from yongtang:26069-tf.zeros-qint32
PiperOrigin-RevId: 351429021 Change-Id: I41a20f65f1d09447cb78e32c0d51dbbcd47b0165
This commit is contained in:
commit
637a24786c
@ -187,6 +187,7 @@ REGISTER_KERNEL(CPU, quint8);
|
||||
REGISTER_KERNEL(CPU, quint16);
|
||||
REGISTER_KERNEL(CPU, qint8);
|
||||
REGISTER_KERNEL(CPU, qint16);
|
||||
REGISTER_KERNEL(CPU, qint32);
|
||||
#undef REGISTER_CPU_KERNEL
|
||||
|
||||
|
||||
|
@ -109,6 +109,7 @@ DEFINE_FILL_CPU(quint8);
|
||||
DEFINE_FILL_CPU(quint16);
|
||||
DEFINE_FILL_CPU(qint8);
|
||||
DEFINE_FILL_CPU(qint16);
|
||||
DEFINE_FILL_CPU(qint32);
|
||||
#undef DEFINE_FILL_CPU
|
||||
|
||||
|
||||
|
@ -478,6 +478,17 @@ class ZerosTest(test.TestCase):
|
||||
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
|
||||
self.assertFalse(np.any(z_value))
|
||||
|
||||
@test_util.disable_tfrt("b/169901260")
|
||||
def testQint32Dtype(self):
|
||||
dtype = dtypes_lib.qint32
|
||||
z = array_ops.zeros([2, 3], dtype=dtype)
|
||||
self.assertEqual(z.dtype, dtype)
|
||||
self.assertEqual([2, 3], z.get_shape())
|
||||
# cast to int32 so that it can be compred with numpy
|
||||
# where [qint|quint][8|16] are not available.
|
||||
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
|
||||
self.assertFalse(np.any(z_value))
|
||||
|
||||
|
||||
class ZerosLikeTest(test.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user