Enable the MLIR-generated GPU kernel for rsqrt.

PiperOrigin-RevId: 355018568
Change-Id: I434d708740030579db4323b4ba0bda0d953167f4
This commit is contained in:
Stephan Herhut 2021-02-01 13:50:28 -08:00 committed by TensorFlower Gardener
parent 94f61a1b31
commit 04849fd02a
3 changed files with 4 additions and 3 deletions

View File

@ -20,7 +20,9 @@ limitations under the License.
// Functor instantiations for rsqrt and rsqrt_grad over Eigen::half, float and
// double. The DEFINE_* macros are declared outside this excerpt.
namespace tensorflow {
namespace functor {
// The rsqrt functor is expanded only when MLIR_GENERATED_GPU_KERNELS_ENABLED
// is not defined. NOTE(review): presumably the MLIR-generated kernel supplies
// rsqrt in that configuration -- confirm against the kernel registration.
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
DEFINE_UNARY3(rsqrt, Eigen::half, float, double);
#endif
// rsqrt_grad sits outside the guard, so it is expanded in both configurations.
DEFINE_SIMPLE_BINARY3(rsqrt_grad, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow

View File

@ -20,8 +20,7 @@ REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double,
complex64, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double);
#endif
#endif

View File

@ -57,6 +57,7 @@ filegroup(
"gpu_op_logical_not.cc",
"gpu_op_neg.cc",
"gpu_op_real.cc",
"gpu_op_rsqrt.cc",
"gpu_op_sqrt.cc",
"gpu_op_square.cc",
"gpu_op_tan.cc",
@ -78,7 +79,6 @@ filegroup(
"gpu_op_is_inf.cc",
"gpu_op_is_nan.cc",
"gpu_op_lgamma.cc",
"gpu_op_rsqrt.cc",
"gpu_op_sign.cc",
"gpu_op_sin.cc",
"gpu_op_sinh.cc",