Enable the generated version of complex and complex_abs.
PiperOrigin-RevId: 354903732 Change-Id: Ic4374567cc361a30f9d2b870adc5e8fdfb38a350
This commit is contained in:
parent
b1952612df
commit
9144fd2b57
@ -21,18 +21,14 @@ REGISTER8(UnaryOp, CPU, "Abs", functor::abs, Eigen::half, bfloat16, float,
|
||||
REGISTER2(UnaryOp, CPU, "ComplexAbs", functor::abs, complex64, complex128);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
|
||||
REGISTER4(UnaryOp, GPU, "Abs", functor::abs, Eigen::half, float, double, int64);
|
||||
#endif
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER2(UnaryOp, GPU, "ComplexAbs", functor::abs, complex64, complex128);
|
||||
#endif
|
||||
|
||||
// A special GPU kernel for int32.
|
||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
||||
// registration requires all int32 inputs and outputs to be in host memory.
|
||||
#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
|
||||
REGISTER_KERNEL_BUILDER(Name("Abs")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("x")
|
||||
@ -40,6 +36,5 @@ REGISTER_KERNEL_BUILDER(Name("Abs")
|
||||
.TypeConstraint<int32>("T"),
|
||||
UnaryOp<CPUDevice, functor::abs<int32>>);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -27,8 +27,7 @@ REGISTER_COMPLEX(CPU, float, complex64);
|
||||
REGISTER_COMPLEX(CPU, double, complex128);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
|
||||
REGISTER_COMPLEX(GPU, float, complex64);
|
||||
REGISTER_COMPLEX(GPU, double, complex128);
|
||||
#endif
|
||||
|
@ -19,9 +19,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
#ifdef MLIR_GENERATED_GPU_KERNELS_ENABLED
|
||||
DEFINE_UNARY2(abs, complex64, complex128);
|
||||
#else
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
|
||||
DEFINE_UNARY6(abs, Eigen::half, float, double, int64, complex64, complex128);
|
||||
#endif
|
||||
} // namespace functor
|
||||
|
@ -19,7 +19,9 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
|
||||
DEFINE_BINARY2(make_complex, float, double);
|
||||
#endif
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -42,6 +42,8 @@ filegroup(
|
||||
"gpu_op_atan.cc",
|
||||
"gpu_op_atanh.cc",
|
||||
"gpu_op_ceil.cc",
|
||||
"gpu_op_complex.cc",
|
||||
"gpu_op_complex_abs.cc",
|
||||
"gpu_op_cos.cc",
|
||||
"gpu_op_cosh.cc",
|
||||
"gpu_op_sqrt.cc",
|
||||
@ -59,8 +61,6 @@ filegroup(
|
||||
"gpu_op_acosh.cc",
|
||||
"gpu_op_asin.cc",
|
||||
"gpu_op_asinh.cc",
|
||||
"gpu_op_complex.cc",
|
||||
"gpu_op_complex_abs.cc",
|
||||
"gpu_op_conj.cc",
|
||||
"gpu_op_erf.cc",
|
||||
"gpu_op_erfc.cc",
|
||||
|
Loading…
Reference in New Issue
Block a user