Expand dtype support for Neg
PiperOrigin-RevId: 317237033 Change-Id: I59c5e45d469f7bf704976b66bc122aaac3982b5e
This commit is contained in:
parent
62082d4072
commit
85ad8031f6
@ -6059,11 +6059,11 @@ I.e., \\(y = -x\\).
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
@ -6802,7 +6802,8 @@ filegroup(
|
||||
"cwise_op_minimum.cc",
|
||||
"cwise_op_mul_1.cc",
|
||||
"cwise_op_mul_2.cc",
|
||||
"cwise_op_neg.cc",
|
||||
"cwise_op_neg_1.cc",
|
||||
"cwise_op_neg_2.cc",
|
||||
"cwise_op_pow.cc",
|
||||
"cwise_op_real.cc",
|
||||
"cwise_op_reciprocal.cc",
|
||||
|
@ -19,8 +19,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
DEFINE_UNARY7(neg, Eigen::half, float, double, int32, int64, complex64,
|
||||
complex128);
|
||||
DEFINE_UNARY4(neg, int8, int16, int32, int64);
|
||||
DEFINE_UNARY6(neg, Eigen::half, float, double, bfloat16, complex64, complex128);
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -16,8 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER8(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32,
|
||||
complex64, int64, complex128, bfloat16);
|
||||
REGISTER4(UnaryOp, CPU, "Neg", functor::neg, int8, int16, int32, int64);
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64);
|
||||
@ -30,8 +29,7 @@ REGISTER_KERNEL_BUILDER(Name("Neg")
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64,
|
||||
complex64, complex128);
|
||||
REGISTER3(UnaryOp, GPU, "Neg", functor::neg, int8, int16, int64);
|
||||
|
||||
// A special GPU kernel for int32.
|
||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
26
tensorflow/core/kernels/cwise_op_neg_2.cc
Normal file
26
tensorflow/core/kernels/cwise_op_neg_2.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER6(UnaryOp, CPU, "Neg", functor::neg, Eigen::half, float, double,
|
||||
bfloat16, complex64, complex128);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, Eigen::half, float, double,
|
||||
bfloat16, complex64, complex128);
|
||||
#endif
|
||||
} // namespace tensorflow
|
@ -201,12 +201,12 @@ REGISTER_OP("ComplexAbs")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
// Declares cwise unary operations signature: 't -> 't
|
||||
#define UNARY() \
|
||||
Input("x: T") \
|
||||
.Output("y: T") \
|
||||
.Attr( \
|
||||
"T: {bfloat16, half, float, double, int32, int64, complex64, " \
|
||||
"complex128}") \
|
||||
#define UNARY() \
|
||||
Input("x: T") \
|
||||
.Output("y: T") \
|
||||
.Attr( \
|
||||
"T: {bfloat16, half, float, double, int8, int16, int32, int64, " \
|
||||
"complex64, complex128}") \
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
|
||||
#define UNARY_REAL() \
|
||||
|
@ -389,16 +389,22 @@ class UnaryOpTest(test.TestCase):
|
||||
2).reshape(1, 3, 2).astype(dtypes_lib.bfloat16.as_numpy_dtype)
|
||||
self._compareCpu(x, np.abs, math_ops.abs)
|
||||
self._compareCpu(x, np.abs, _ABS)
|
||||
self._compareBoth(x, np.negative, math_ops.negative)
|
||||
self._compareBoth(x, np.negative, _NEG)
|
||||
|
||||
def testInt8Basic(self):
|
||||
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int8)
|
||||
self._compareCpu(x, np.abs, math_ops.abs)
|
||||
self._compareCpu(x, np.abs, _ABS)
|
||||
self._compareBoth(x, np.negative, math_ops.negative)
|
||||
self._compareBoth(x, np.negative, _NEG)
|
||||
|
||||
def testInt16Basic(self):
|
||||
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int16)
|
||||
self._compareCpu(x, np.abs, math_ops.abs)
|
||||
self._compareCpu(x, np.abs, _ABS)
|
||||
self._compareBoth(x, np.negative, math_ops.negative)
|
||||
self._compareBoth(x, np.negative, _NEG)
|
||||
|
||||
def testInt32Basic(self):
|
||||
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
|
||||
|
Loading…
Reference in New Issue
Block a user