From 928ff3e27ba8567b06ada34d85ffaffad1e92c10 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 6 Apr 2020 21:07:14 +0000 Subject: [PATCH 1/4] Add uint16, uint32, uint64 support for tf.math.equal This PR tries to address the issue raised in 26069 where tf.math.equal does not suport basic data types such as uint16, uint32, and uint64. While there might be some restrictions on comparision (e.g. >, <, etc) for certain data types due to CPU or GPU, the comparision of basic data types such as uint16, uint32, uint64 are very much simple operation across the board. They are important in many ops as well. For that reason, it makes sense to make sure at least all basic data types support `equal`. This PR adds the missing uint16, uint32, uint64 support for tf.math.equal Signed-off-by: Yong Tang --- tensorflow/core/kernels/cwise_op_equal_to_1.cc | 1 + tensorflow/core/ops/math_ops.cc | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc index ac66e558d03..64cd784af73 100644 --- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc @@ -18,6 +18,7 @@ limitations under the License. namespace tensorflow { REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double, uint8, int8, int16, bfloat16); +REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64); REGISTER_KERNEL_BUILDER( Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint("T"), ApproximateEqualOp); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index e441c73cc57..232a23535de 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -717,8 +717,8 @@ REGISTER_OP("GreaterEqual").COMPARISON(); .SetIsCommutative() \ .Attr( \ "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \ - "int64, complex64, quint8, qint8, qint32, string, bool, " \ - "complex128}") \ + "int64, uint16, uint32, uint64, complex64, " \ + "quint8, qint8, qint32, string, bool, complex128}") \ .Attr("incompatible_shape_error: bool = true") \ .SetShapeFn([](InferenceContext* c) { \ ShapeHandle x = c->input(0); \ From 67ef5139174c5ccfceae9219a254cdd94a2a9672 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 6 Apr 2020 22:52:43 +0000 Subject: [PATCH 2/4] Add uint16, uint32, uint64 support for tf.math.not_equal Signed-off-by: Yong Tang --- tensorflow/core/kernels/cwise_op_not_equal_to_1.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc index f207158b843..0b518d911b0 100644 --- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc @@ -18,6 +18,7 @@ limitations under the License. namespace tensorflow { REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8, int8, int16, bfloat16); +REGISTER3(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32, uint64); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8); From 42b0014216ab04f704967b722f7062df8a4180e1 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 6 Apr 2020 22:53:09 +0000 Subject: [PATCH 3/4] Add test cases of uint16, uint32, uint64 support for tf.math.[equal|not_equal] Signed-off-by: Yong Tang --- .../kernel_tests/cwise_ops_binary_test.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py index 153d0e78fd7..1f4b570027f 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py @@ -948,6 +948,31 @@ class ComparisonOpTest(test.TestCase): "Incompatible shapes|Dimensions must be equal"): f(x.astype(t), y.astype(t)) + def testEqualDType(self): + dtypes = [np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64] + x = np.asarray([0, 1, 2, 3, 4]) + y = np.asarray([0, 1, 2, 3, 4]) + for dtype in dtypes: + xt = x.astype(dtype) + yt = y.astype(dtype) + cmp_eq = math_ops.equal(xt, yt) + cmp_ne = math_ops.not_equal(xt, yt) + values = self.evaluate([cmp_eq, cmp_ne]) + self.assertAllEqual( + [[True, True, True, True, True], + [False, False, False, False, False]], values) + for dtype in [np.complex64, np.complex128]: + xt = x.astype(dtype) + xt -= 1j * xt + yt = y.astype(dtype) + yt -= 1j * yt + cmp_eq = math_ops.equal(xt, yt) + cmp_ne = math_ops.not_equal(xt, yt) + values = self.evaluate([cmp_eq, cmp_ne]) + self.assertAllEqual( + [[True, True, True, True, True], + [False, False, False, False, False]], values) + if __name__ == "__main__": test.main() From 9319da22772b62e3e914bbdbada7a4a5d106c636 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 6 Apr 2020 23:59:44 +0000 Subject: [PATCH 4/4] Pylint fix Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/cwise_ops_binary_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py index 1f4b570027f..67fc8519323 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py @@ -949,7 +949,11 @@ class ComparisonOpTest(test.TestCase): f(x.astype(t), y.astype(t)) def testEqualDType(self): - dtypes = [np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64] + dtypes = [ + np.float16, np.float32, np.float64, + np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + ] x = np.asarray([0, 1, 2, 3, 4]) y = np.asarray([0, 1, 2, 3, 4]) for dtype in dtypes: