Expand dtype support for Neg
PiperOrigin-RevId: 317237033 Change-Id: I59c5e45d469f7bf704976b66bc122aaac3982b5e
This commit is contained in:
parent
62082d4072
commit
85ad8031f6
tensorflow
compiler/mlir/tensorflow/ir
core
python/kernel_tests
@ -6059,11 +6059,11 @@ I.e., \\(y = -x\\).
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x
|
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y
|
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y
|
||||||
);
|
);
|
||||||
|
|
||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
@ -6802,7 +6802,8 @@ filegroup(
|
|||||||
"cwise_op_minimum.cc",
|
"cwise_op_minimum.cc",
|
||||||
"cwise_op_mul_1.cc",
|
"cwise_op_mul_1.cc",
|
||||||
"cwise_op_mul_2.cc",
|
"cwise_op_mul_2.cc",
|
||||||
"cwise_op_neg.cc",
|
"cwise_op_neg_1.cc",
|
||||||
|
"cwise_op_neg_2.cc",
|
||||||
"cwise_op_pow.cc",
|
"cwise_op_pow.cc",
|
||||||
"cwise_op_real.cc",
|
"cwise_op_real.cc",
|
||||||
"cwise_op_reciprocal.cc",
|
"cwise_op_reciprocal.cc",
|
||||||
|
@ -19,8 +19,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace functor {
|
namespace functor {
|
||||||
DEFINE_UNARY7(neg, Eigen::half, float, double, int32, int64, complex64,
|
DEFINE_UNARY4(neg, int8, int16, int32, int64);
|
||||||
complex128);
|
DEFINE_UNARY6(neg, Eigen::half, float, double, bfloat16, complex64, complex128);
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -16,8 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
REGISTER8(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32,
|
REGISTER4(UnaryOp, CPU, "Neg", functor::neg, int8, int16, int32, int64);
|
||||||
complex64, int64, complex128, bfloat16);
|
|
||||||
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64);
|
REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64);
|
||||||
@ -30,8 +29,7 @@ REGISTER_KERNEL_BUILDER(Name("Neg")
|
|||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64,
|
REGISTER3(UnaryOp, GPU, "Neg", functor::neg, int8, int16, int64);
|
||||||
complex64, complex128);
|
|
||||||
|
|
||||||
// A special GPU kernel for int32.
|
// A special GPU kernel for int32.
|
||||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
26
tensorflow/core/kernels/cwise_op_neg_2.cc
Normal file
26
tensorflow/core/kernels/cwise_op_neg_2.cc
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
REGISTER6(UnaryOp, CPU, "Neg", functor::neg, Eigen::half, float, double,
|
||||||
|
bfloat16, complex64, complex128);
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, Eigen::half, float, double,
|
||||||
|
bfloat16, complex64, complex128);
|
||||||
|
#endif
|
||||||
|
} // namespace tensorflow
|
@ -205,8 +205,8 @@ REGISTER_OP("ComplexAbs")
|
|||||||
Input("x: T") \
|
Input("x: T") \
|
||||||
.Output("y: T") \
|
.Output("y: T") \
|
||||||
.Attr( \
|
.Attr( \
|
||||||
"T: {bfloat16, half, float, double, int32, int64, complex64, " \
|
"T: {bfloat16, half, float, double, int8, int16, int32, int64, " \
|
||||||
"complex128}") \
|
"complex64, complex128}") \
|
||||||
.SetShapeFn(shape_inference::UnchangedShape)
|
.SetShapeFn(shape_inference::UnchangedShape)
|
||||||
|
|
||||||
#define UNARY_REAL() \
|
#define UNARY_REAL() \
|
||||||
|
@ -389,16 +389,22 @@ class UnaryOpTest(test.TestCase):
|
|||||||
2).reshape(1, 3, 2).astype(dtypes_lib.bfloat16.as_numpy_dtype)
|
2).reshape(1, 3, 2).astype(dtypes_lib.bfloat16.as_numpy_dtype)
|
||||||
self._compareCpu(x, np.abs, math_ops.abs)
|
self._compareCpu(x, np.abs, math_ops.abs)
|
||||||
self._compareCpu(x, np.abs, _ABS)
|
self._compareCpu(x, np.abs, _ABS)
|
||||||
|
self._compareBoth(x, np.negative, math_ops.negative)
|
||||||
|
self._compareBoth(x, np.negative, _NEG)
|
||||||
|
|
||||||
def testInt8Basic(self):
|
def testInt8Basic(self):
|
||||||
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int8)
|
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int8)
|
||||||
self._compareCpu(x, np.abs, math_ops.abs)
|
self._compareCpu(x, np.abs, math_ops.abs)
|
||||||
self._compareCpu(x, np.abs, _ABS)
|
self._compareCpu(x, np.abs, _ABS)
|
||||||
|
self._compareBoth(x, np.negative, math_ops.negative)
|
||||||
|
self._compareBoth(x, np.negative, _NEG)
|
||||||
|
|
||||||
def testInt16Basic(self):
|
def testInt16Basic(self):
|
||||||
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int16)
|
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int16)
|
||||||
self._compareCpu(x, np.abs, math_ops.abs)
|
self._compareCpu(x, np.abs, math_ops.abs)
|
||||||
self._compareCpu(x, np.abs, _ABS)
|
self._compareCpu(x, np.abs, _ABS)
|
||||||
|
self._compareBoth(x, np.negative, math_ops.negative)
|
||||||
|
self._compareBoth(x, np.negative, _NEG)
|
||||||
|
|
||||||
def testInt32Basic(self):
|
def testInt32Basic(self):
|
||||||
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
|
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
|
||||||
|
Loading…
Reference in New Issue
Block a user