diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc index ac66e558d03..64cd784af73 100644 --- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc @@ -18,6 +18,7 @@ limitations under the License. namespace tensorflow { REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double, uint8, int8, int16, bfloat16); +REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64); REGISTER_KERNEL_BUILDER( Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint("T"), ApproximateEqualOp); diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc index f207158b843..4de69edd21d 100644 --- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc @@ -18,6 +18,8 @@ limitations under the License. namespace tensorflow { REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8, int8, int16, bfloat16); +REGISTER3(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32, + uint64); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 4252044999d..8f4a9c8dca8 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -717,8 +717,8 @@ REGISTER_OP("GreaterEqual").COMPARISON(); .SetIsCommutative() \ .Attr( \ "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \ - "int64, complex64, quint8, qint8, qint32, string, bool, " \ - "complex128}") \ + "int64, uint16, uint32, uint64, complex64, " \ + "quint8, qint8, qint32, string, bool, complex128}") \ .Attr("incompatible_shape_error: bool = true") \ .SetShapeFn([](InferenceContext* c) { \ ShapeHandle x = c->input(0); \ diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py index 71006f28f43..4c6a41bf205 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py @@ -953,6 +953,44 @@ class ComparisonOpTest(test.TestCase): "Incompatible shapes|Dimensions must be equal"): f(x.astype(t), y.astype(t)) + def testEqualDType(self): + dtypes = [ + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.bool, + ] + x = np.asarray([0, 1, 2, 3, 4]) + y = np.asarray([0, 1, 2, 3, 4]) + for dtype in dtypes: + xt = x.astype(dtype) + yt = y.astype(dtype) + cmp_eq = math_ops.equal(xt, yt) + cmp_ne = math_ops.not_equal(xt, yt) + values = self.evaluate([cmp_eq, cmp_ne]) + self.assertAllEqual( + [[True, True, True, True, True], [False, False, False, False, False]], + values) + for dtype in [np.complex64, np.complex128]: + xt = x.astype(dtype) + xt -= 1j * xt + yt = y.astype(dtype) + yt -= 1j * yt + cmp_eq = math_ops.equal(xt, yt) + cmp_ne = math_ops.not_equal(xt, yt) + values = self.evaluate([cmp_eq, cmp_ne]) + self.assertAllEqual( + [[True, True, True, True, True], [False, False, False, False, False]], + values) + if __name__ == "__main__": test.main()