From 04849fd02a2e1cf0e72c0aa5647c76b7b9d7629e Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Mon, 1 Feb 2021 13:50:28 -0800 Subject: [PATCH] Enable the generated version of rsqrt. PiperOrigin-RevId: 355018568 Change-Id: I434d708740030579db4323b4ba0bda0d953167f4 --- tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc | 2 ++ tensorflow/core/kernels/cwise_op_rsqrt.cc | 3 +-- tensorflow/core/kernels/mlir_generated/BUILD | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc index 5c243cff294..dfeaaa68573 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc @@ -20,7 +20,9 @@ limitations under the License. namespace tensorflow { namespace functor { +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) DEFINE_UNARY3(rsqrt, Eigen::half, float, double); +#endif DEFINE_SIMPLE_BINARY3(rsqrt_grad, Eigen::half, float, double); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc index 7d673887a31..4741b7eb438 100644 --- a/tensorflow/core/kernels/cwise_op_rsqrt.cc +++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc @@ -20,8 +20,7 @@ REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double, complex64, complex128); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ - !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED) +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double); #endif #endif diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index 41414446380..8f934c7a89b 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -57,6 +57,7 @@ filegroup( "gpu_op_logical_not.cc", "gpu_op_neg.cc", "gpu_op_real.cc", + "gpu_op_rsqrt.cc", "gpu_op_sqrt.cc", "gpu_op_square.cc", "gpu_op_tan.cc", @@ -78,7 +79,6 @@ filegroup( "gpu_op_is_inf.cc", "gpu_op_is_nan.cc", "gpu_op_lgamma.cc", - "gpu_op_rsqrt.cc", "gpu_op_sign.cc", "gpu_op_sin.cc", "gpu_op_sinh.cc",