Add tf.qint32 support for tf.zeros
This PR is part of 26069 where tf.zeros does not support basic type of tf.qint32 while all other qtypes have been supported (tf.{qint8|qint16|quint8|quint16} supported). Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
8730806d23
commit
6bf3358080
@ -187,6 +187,7 @@ REGISTER_KERNEL(CPU, quint8);
|
||||
REGISTER_KERNEL(CPU, quint16);
|
||||
REGISTER_KERNEL(CPU, qint8);
|
||||
REGISTER_KERNEL(CPU, qint16);
|
||||
REGISTER_KERNEL(CPU, qint32);
|
||||
#undef REGISTER_CPU_KERNEL
|
||||
|
||||
|
||||
|
@ -109,6 +109,7 @@ DEFINE_FILL_CPU(quint8);
|
||||
DEFINE_FILL_CPU(quint16);
|
||||
DEFINE_FILL_CPU(qint8);
|
||||
DEFINE_FILL_CPU(qint16);
|
||||
DEFINE_FILL_CPU(qint32);
|
||||
#undef DEFINE_FILL_CPU
|
||||
|
||||
|
||||
|
@ -478,6 +478,17 @@ class ZerosTest(test.TestCase):
|
||||
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
|
||||
self.assertFalse(np.any(z_value))
|
||||
|
||||
@test_util.disable_tfrt("b/169901260")
|
||||
def testQint32Dtype(self):
|
||||
dtype = dtypes_lib.qint32
|
||||
z = array_ops.zeros([2, 3], dtype=dtype)
|
||||
self.assertEqual(z.dtype, dtype)
|
||||
self.assertEqual([2, 3], z.get_shape())
|
||||
# cast to int32 so that it can be compred with numpy
|
||||
# where [qint|quint][8|16] are not available.
|
||||
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
|
||||
self.assertFalse(np.any(z_value))
|
||||
|
||||
|
||||
class ZerosLikeTest(test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user