Enable generated versions of lgamma and digamma.

PiperOrigin-RevId: 356310920
Change-Id: Icb3af5e20b260e5d86f62f3e5b0e3f55f94b3451
This commit is contained in:
Stephan Herhut 2021-02-08 11:24:51 -08:00 committed by TensorFlower Gardener
parent 7516f6ee75
commit ed365f7817
5 changed files with 8 additions and 6 deletions

View File

@ -21,8 +21,7 @@ REGISTER3(UnaryOp, CPU, "Digamma", functor::digamma, float, Eigen::half,
double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "Digamma", functor::digamma, float, Eigen::half,
double);
#endif

View File

@ -19,7 +19,9 @@ limitations under the License.
namespace tensorflow {
namespace functor {
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
DEFINE_UNARY3(digamma, Eigen::half, float, double);
#endif
} // namespace functor
} // namespace tensorflow

View File

@ -19,7 +19,9 @@ limitations under the License.
namespace tensorflow {
namespace functor {
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
DEFINE_UNARY3(lgamma, Eigen::half, float, double);
#endif
} // namespace functor
} // namespace tensorflow

View File

@ -20,8 +20,7 @@ namespace tensorflow {
REGISTER3(UnaryOp, CPU, "Lgamma", functor::lgamma, float, Eigen::half, double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "Lgamma", functor::lgamma, float, Eigen::half, double);
#endif
#endif

View File

@ -54,6 +54,7 @@ filegroup(
"gpu_op_conj.cc",
"gpu_op_cos.cc",
"gpu_op_cosh.cc",
"gpu_op_digamma.cc",
"gpu_op_erf.cc",
"gpu_op_erfc.cc",
"gpu_op_floor.cc",
@ -62,6 +63,7 @@ filegroup(
"gpu_op_is_finite.cc",
"gpu_op_is_inf.cc",
"gpu_op_is_nan.cc",
"gpu_op_lgamma.cc",
"gpu_op_log.cc",
"gpu_op_log1p.cc",
"gpu_op_logical_not.cc",
@ -81,10 +83,8 @@ filegroup(
filegroup(
name = "experimental_unary_gpu_kernel_srcs",
srcs = [
"gpu_op_digamma.cc",
"gpu_op_exp.cc",
"gpu_op_expm1.cc",
"gpu_op_lgamma.cc",
"gpu_op_sign.cc",
],
compatible_with = get_compatible_with_cloud(),