diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 43efc63ee45..8d8e5d9896b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -3229,8 +3229,8 @@ tf.math.equal(x, y) ==> array([True, True]) }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, DefaultValuedAttr:$incompatible_shape_error ); @@ -7083,8 +7083,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> { }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, DefaultValuedAttr:$incompatible_shape_error ); @@ -14110,4 +14110,4 @@ execution the transfer corresponds to.}]>:$dynamic_key, let results = (outs); TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; -} \ No newline at end of file +} diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc index 41eadd6da6f..af72aa3418c 100644 --- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc @@ -18,7 +18,8 @@ limitations under the License. namespace tensorflow { REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double, uint8, int8, int16, bfloat16); -REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64); +REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64, + qint8, qint16, quint8, quint16); REGISTER_KERNEL_BUILDER( Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint("T"), ApproximateEqualOp); diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc index f0dbac19bd7..49edd3f3ceb 100644 --- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc @@ -18,8 +18,8 @@ limitations under the License. namespace tensorflow { REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8, int8, int16, bfloat16); -REGISTER3(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32, - uint64); +REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32, + uint64, qint8, qint16, quint8, quint16); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 15d1a1f86e2..d6fde7248ab 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -703,27 +703,24 @@ REGISTER_OP("GreaterEqual").COMPARISON(); // -------------------------------------------------------------------------- -#define EQUALITY_COMPARISON() \ - Input("x: T") \ - .Input("y: T") \ - .Output("z: bool") \ - .SetIsCommutative() \ - .Attr( \ - "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \ - "int64, uint16, uint32, uint64, complex64, " \ - "quint8, qint8, qint32, string, bool, complex128}") \ - .Attr("incompatible_shape_error: bool = true") \ - .SetShapeFn([](InferenceContext* c) { \ - ShapeHandle x = c->input(0); \ - ShapeHandle y = c->input(1); \ - ShapeHandle output; \ - bool incompatible_shape_error; \ - TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error", \ - &incompatible_shape_error)); \ - TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( \ - c, x, y, incompatible_shape_error, &output)); \ - c->set_output(0, output); \ - return Status::OK(); \ +#define EQUALITY_COMPARISON() \ + Input("x: T") \ + .Input("y: T") \ + .Output("z: bool") \ + .SetIsCommutative() \ + .Attr("T: type") \ + .Attr("incompatible_shape_error: bool = true") \ + .SetShapeFn([](InferenceContext* c) { \ + ShapeHandle x = c->input(0); \ + ShapeHandle y = c->input(1); \ + ShapeHandle output; \ + bool incompatible_shape_error; \ + TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error", \ + &incompatible_shape_error)); \ + TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( \ + c, x, y, incompatible_shape_error, &output)); \ + c->set_output(0, output); \ + return Status::OK(); \ }) REGISTER_OP("Equal").EQUALITY_COMPARISON(); diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py index 30a04ab3d3e..98832dd9885 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py @@ -991,6 +991,25 @@ class ComparisonOpTest(test.TestCase): [[True, True, True, True, True], [False, False, False, False, False]], values) + def testEqualQuantizeDType(self): + dtypes = [ + dtypes_lib.qint8, + dtypes_lib.qint16, + dtypes_lib.quint8, + dtypes_lib.quint16, + ] + x = np.asarray([0, 1, 2, 3, 4]) + y = np.asarray([0, 1, 2, 3, 4]) + for dtype in dtypes: + xt = x.astype(dtype.as_numpy_dtype) + yt = y.astype(dtype.as_numpy_dtype) + cmp_eq = math_ops.equal(xt, yt) + cmp_ne = math_ops.not_equal(xt, yt) + values = self.evaluate([cmp_eq, cmp_ne]) + self.assertAllEqual( + [[True, True, True, True, True], [False, False, False, False, False]], + values) + if __name__ == "__main__": test.main()