diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py index 153d0e78fd7..1f4b570027f 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py @@ -948,6 +948,31 @@ class ComparisonOpTest(test.TestCase): "Incompatible shapes|Dimensions must be equal"): f(x.astype(t), y.astype(t)) + def testEqualDType(self): + dtypes = [np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64] + x = np.asarray([0, 1, 2, 3, 4]) + y = np.asarray([0, 1, 2, 3, 4]) + for dtype in dtypes: + xt = x.astype(dtype) + yt = y.astype(dtype) + cmp_eq = math_ops.equal(xt, yt) + cmp_ne = math_ops.not_equal(xt, yt) + values = self.evaluate([cmp_eq, cmp_ne]) + self.assertAllEqual( + [[True, True, True, True, True], + [False, False, False, False, False]], values) + for dtype in [np.complex64, np.complex128]: + xt = x.astype(dtype) + xt -= 1j * xt + yt = y.astype(dtype) + yt -= 1j * yt + cmp_eq = math_ops.equal(xt, yt) + cmp_ne = math_ops.not_equal(xt, yt) + values = self.evaluate([cmp_eq, cmp_ne]) + self.assertAllEqual( + [[True, True, True, True, True], + [False, False, False, False, False]], values) + if __name__ == "__main__": test.main()