From 3a7368d06f05757aa3e8764fa8428280903cb62b Mon Sep 17 00:00:00 2001 From: Yong Tang <yong.tang.github@outlook.com> Date: Tue, 28 Jul 2020 02:13:16 +0000 Subject: [PATCH] Add qint8/qint16/quint8/quint16 support for tf.math.equal/tf/math.not_equal This PR add qint8/qint16/quint8/quint16 support for tf.math.equal/tf/math.not_equal, as was requested in 26069. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> --- .../core/kernels/cwise_op_equal_to_1.cc | 3 ++- .../core/kernels/cwise_op_not_equal_to_1.cc | 4 ++-- tensorflow/core/ops/math_ops.cc | 5 +---- .../kernel_tests/cwise_ops_binary_test.py | 19 +++++++++++++++++++ 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc index 64cd784af73..d5cf2b1992d 100644 --- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc @@ -18,7 +18,8 @@ limitations under the License. namespace tensorflow { REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double, uint8, int8, int16, bfloat16); -REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64); +REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64, + qint8, qint16, quint8, quint16); REGISTER_KERNEL_BUILDER( Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"), ApproximateEqualOp<CPUDevice, float>); diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc index 4de69edd21d..744fecbade7 100644 --- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc @@ -18,8 +18,8 @@ limitations under the License. namespace tensorflow { REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8, int8, int16, bfloat16); -REGISTER3(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32, - uint64); +REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32, + uint64, qint8, qint16, quint8, quint16); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 99be4e2fcd8..83fac604caa 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -708,10 +708,7 @@ REGISTER_OP("GreaterEqual").COMPARISON(); .Input("y: T") \ .Output("z: bool") \ .SetIsCommutative() \ - .Attr( \ - "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \ - "int64, uint16, uint32, uint64, complex64, " \ - "quint8, qint8, qint32, string, bool, complex128}") \ + .Attr("T: type") \ .Attr("incompatible_shape_error: bool = true") \ .SetShapeFn([](InferenceContext* c) { \ ShapeHandle x = c->input(0); \ diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py index 50e6c0ad91f..729933e097d 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py @@ -991,6 +991,25 @@ class ComparisonOpTest(test.TestCase): [[True, True, True, True, True], [False, False, False, False, False]], values) + def testEqualQuantizeDType(self): + dtypes = [ + dtypes_lib.qint8, + dtypes_lib.qint16, + dtypes_lib.quint8, + dtypes_lib.quint16, + ] + x = np.asarray([0, 1, 2, 3, 4]) + y = np.asarray([0, 1, 2, 3, 4]) + for dtype in dtypes: + xt = x.astype(dtype.as_numpy_dtype) + yt = y.astype(dtype.as_numpy_dtype) + cmp_eq = math_ops.equal(xt, yt) + cmp_ne = math_ops.not_equal(xt, yt) + values = self.evaluate([cmp_eq, cmp_ne]) + self.assertAllEqual( + [[True, True, True, True, True], [False, False, False, False, False]], + values) + if __name__ == "__main__": test.main()