From 3a7368d06f05757aa3e8764fa8428280903cb62b Mon Sep 17 00:00:00 2001
From: Yong Tang <yong.tang.github@outlook.com>
Date: Tue, 28 Jul 2020 02:13:16 +0000
Subject: [PATCH] Add qint8/qint16/quint8/quint16 support for
 tf.math.equal/tf/math.not_equal

This PR add qint8/qint16/quint8/quint16 support for tf.math.equal/tf/math.not_equal,
as was requested in 26069.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
---
 .../core/kernels/cwise_op_equal_to_1.cc       |  3 ++-
 .../core/kernels/cwise_op_not_equal_to_1.cc   |  4 ++--
 tensorflow/core/ops/math_ops.cc               |  5 +----
 .../kernel_tests/cwise_ops_binary_test.py     | 19 +++++++++++++++++++
 4 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
index 64cd784af73..d5cf2b1992d 100644
--- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
@@ -18,7 +18,8 @@ limitations under the License.
 namespace tensorflow {
 REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
           uint8, int8, int16, bfloat16);
-REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64);
+REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64,
+          qint8, qint16, quint8, quint16);
 REGISTER_KERNEL_BUILDER(
     Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
     ApproximateEqualOp<CPUDevice, float>);
diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
index 4de69edd21d..744fecbade7 100644
--- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
@@ -18,8 +18,8 @@ limitations under the License.
 namespace tensorflow {
 REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
           double, uint8, int8, int16, bfloat16);
-REGISTER3(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32,
-          uint64);
+REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32,
+          uint64, qint8, qint16, quint8, quint16);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
           double, uint8);
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 99be4e2fcd8..83fac604caa 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -708,10 +708,7 @@ REGISTER_OP("GreaterEqual").COMPARISON();
       .Input("y: T")                                                       \
       .Output("z: bool")                                                   \
       .SetIsCommutative()                                                  \
-      .Attr(                                                               \
-          "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
-          "int64, uint16, uint32, uint64, complex64, "                     \
-          "quint8, qint8, qint32, string, bool, complex128}")              \
+      .Attr("T: type")                                                     \
       .Attr("incompatible_shape_error: bool = true")                       \
       .SetShapeFn([](InferenceContext* c) {                                \
         ShapeHandle x = c->input(0);                                       \
diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
index 50e6c0ad91f..729933e097d 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
@@ -991,6 +991,25 @@ class ComparisonOpTest(test.TestCase):
           [[True, True, True, True, True], [False, False, False, False, False]],
           values)
 
+  def testEqualQuantizeDType(self):
+    dtypes = [
+        dtypes_lib.qint8,
+        dtypes_lib.qint16,
+        dtypes_lib.quint8,
+        dtypes_lib.quint16,
+    ]
+    x = np.asarray([0, 1, 2, 3, 4])
+    y = np.asarray([0, 1, 2, 3, 4])
+    for dtype in dtypes:
+      xt = x.astype(dtype.as_numpy_dtype)
+      yt = y.astype(dtype.as_numpy_dtype)
+      cmp_eq = math_ops.equal(xt, yt)
+      cmp_ne = math_ops.not_equal(xt, yt)
+      values = self.evaluate([cmp_eq, cmp_ne])
+      self.assertAllEqual(
+          [[True, True, True, True, True], [False, False, False, False, False]],
+          values)
+
 
 if __name__ == "__main__":
   test.main()