From 85ad8031f60536361de71dd689c9d88848fefed6 Mon Sep 17 00:00:00 2001
From: Gaurav Jain <gjn@google.com>
Date: Thu, 18 Jun 2020 20:27:03 -0700
Subject: [PATCH] Expand dtype support for Neg

PiperOrigin-RevId: 317237033
Change-Id: I59c5e45d469f7bf704976b66bc122aaac3982b5e
---
 .../mlir/tensorflow/ir/tf_generated_ops.td    |  4 +--
 tensorflow/core/kernels/BUILD                 |  3 ++-
 .../core/kernels/cwise_op_gpu_neg.cu.cc       |  4 +--
 .../{cwise_op_neg.cc => cwise_op_neg_1.cc}    |  6 ++---
 tensorflow/core/kernels/cwise_op_neg_2.cc     | 26 +++++++++++++++++++
 tensorflow/core/ops/math_ops.cc               | 12 ++++-----
 .../kernel_tests/cwise_ops_unary_test.py      |  6 +++++
 7 files changed, 46 insertions(+), 15 deletions(-)
 rename tensorflow/core/kernels/{cwise_op_neg.cc => cwise_op_neg_1.cc} (87%)
 create mode 100644 tensorflow/core/kernels/cwise_op_neg_2.cc

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index dcd083fc398..3b1f3eec699 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -6059,11 +6059,11 @@ I.e., \\(y = -x\\).
   }];
 
   let arguments = (ins
-    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
   );
 
   let results = (outs
-    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ffe2a035591..279dff92c58 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6802,7 +6802,8 @@ filegroup(
         "cwise_op_minimum.cc",
         "cwise_op_mul_1.cc",
         "cwise_op_mul_2.cc",
-        "cwise_op_neg.cc",
+        "cwise_op_neg_1.cc",
+        "cwise_op_neg_2.cc",
         "cwise_op_pow.cc",
         "cwise_op_real.cc",
         "cwise_op_reciprocal.cc",
diff --git a/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc
index ea1ca623560..4f7bb9b2075 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc
@@ -19,8 +19,8 @@ limitations under the License.
 
 namespace tensorflow {
 namespace functor {
-DEFINE_UNARY7(neg, Eigen::half, float, double, int32, int64, complex64,
-              complex128);
+DEFINE_UNARY4(neg, int8, int16, int32, int64);
+DEFINE_UNARY6(neg, Eigen::half, float, double, bfloat16, complex64, complex128);
 }  // namespace functor
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/cwise_op_neg.cc b/tensorflow/core/kernels/cwise_op_neg_1.cc
similarity index 87%
rename from tensorflow/core/kernels/cwise_op_neg.cc
rename to tensorflow/core/kernels/cwise_op_neg_1.cc
index f52cf6c8b91..18a7c61be90 100644
--- a/tensorflow/core/kernels/cwise_op_neg.cc
+++ b/tensorflow/core/kernels/cwise_op_neg_1.cc
@@ -16,8 +16,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/cwise_ops_common.h"
 
 namespace tensorflow {
-REGISTER8(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32,
-          complex64, int64, complex128, bfloat16);
+REGISTER4(UnaryOp, CPU, "Neg", functor::neg, int8, int16, int32, int64);
 
 #ifdef TENSORFLOW_USE_SYCL
 REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64);
@@ -30,8 +29,7 @@ REGISTER_KERNEL_BUILDER(Name("Neg")
 #endif  // TENSORFLOW_USE_SYCL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-REGISTER6(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64,
-          complex64, complex128);
+REGISTER3(UnaryOp, GPU, "Neg", functor::neg, int8, int16, int64);
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_neg_2.cc b/tensorflow/core/kernels/cwise_op_neg_2.cc
new file mode 100644
index 00000000000..5ea78ad665c
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_neg_2.cc
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER6(UnaryOp, CPU, "Neg", functor::neg, Eigen::half, float, double,
+          bfloat16, complex64, complex128);
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+REGISTER6(UnaryOp, GPU, "Neg", functor::neg, Eigen::half, float, double,
+          bfloat16, complex64, complex128);
+#endif
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index b81bb9d3afc..2a70f420260 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -201,12 +201,12 @@ REGISTER_OP("ComplexAbs")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 // Declares cwise unary operations signature: 't -> 't
-#define UNARY()                                                          \
-  Input("x: T")                                                          \
-      .Output("y: T")                                                    \
-      .Attr(                                                             \
-          "T: {bfloat16, half, float, double, int32, int64, complex64, " \
-          "complex128}")                                                 \
+#define UNARY()                                                            \
+  Input("x: T")                                                            \
+      .Output("y: T")                                                      \
+      .Attr(                                                               \
+          "T: {bfloat16, half, float, double, int8, int16, int32, int64, " \
+          "complex64, complex128}")                                        \
       .SetShapeFn(shape_inference::UnchangedShape)
 
 #define UNARY_REAL()                              \
diff --git a/tensorflow/python/kernel_tests/cwise_ops_unary_test.py b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
index f4beaabc29a..df848a653d4 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
@@ -389,16 +389,22 @@ class UnaryOpTest(test.TestCase):
                   2).reshape(1, 3, 2).astype(dtypes_lib.bfloat16.as_numpy_dtype)
     self._compareCpu(x, np.abs, math_ops.abs)
     self._compareCpu(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
 
   def testInt8Basic(self):
     x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int8)
     self._compareCpu(x, np.abs, math_ops.abs)
     self._compareCpu(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
 
   def testInt16Basic(self):
     x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int16)
     self._compareCpu(x, np.abs, math_ops.abs)
     self._compareCpu(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
 
   def testInt32Basic(self):
     x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)