Enabling several cwise op kernels for ROCm

This commit is contained in:
Eugene Kuznetsov 2020-01-15 18:27:03 -08:00
parent db8a74a737
commit d45988d245
6 changed files with 12 additions and 33 deletions

View File

@@ -29,30 +29,14 @@ REGISTER5(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// ROCM TODO: re-enable complex64 / complex128 after compiler fix
#if GOOGLE_CUDA
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
uint16, int16, int64, complex64, complex128);
#elif TENSORFLOW_USE_ROCM
REGISTER7(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
uint16, int16, int64);
#endif
REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
int64);
// ROCM TODO: re-enable complex64 / complex128 after compiler fix
#if GOOGLE_CUDA
REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
complex64, complex128);
#elif TENSORFLOW_USE_ROCM
REGISTER3(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double);
#endif
#if GOOGLE_CUDA
REGISTER5(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
double, complex64, complex128);
#elif TENSORFLOW_USE_ROCM
REGISTER3(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
double);
#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel

View File

@@ -13,14 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
namespace tensorflow {
namespace functor {
#if GOOGLE_CUDA
DEFINE_UNARY2(get_angle, complex64, complex128);
#endif
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -21,12 +21,7 @@ namespace tensorflow {
namespace functor {
DEFINE_BINARY10(div, Eigen::half, float, double, uint8, uint16, int16, int32,
int64, complex64, complex128);
#if GOOGLE_CUDA
DEFINE_BINARY5(div_no_nan, Eigen::half, float, double, complex64, complex128);
#elif TENSORFLOW_USE_ROCM
// ROCM TODO: fix compiler error for complex64 / complex128 division
DEFINE_BINARY3(div_no_nan, Eigen::half, float, double);
#endif
} // namespace functor
} // namespace tensorflow

View File

@@ -44,13 +44,9 @@ REGISTER_KERNEL_BUILDER(Name("Mul")
BinaryOp<CPUDevice, functor::mul<int32>>);
#endif
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER5(BinaryOp, GPU, "MulNoNan", functor::mul_no_nan, Eigen::half, float,
double, complex64, complex128);
#elif TENSORFLOW_USE_ROCM
// ROCM TODO: re-enable complex64 / complex128 after compiler fix
REGISTER3(BinaryOp, GPU, "MulNoNan", functor::mul_no_nan, Eigen::half, float,
double);
#endif
#ifdef TENSORFLOW_USE_SYCL

View File

@@ -36,6 +36,8 @@ REGISTER_SYCL_KERNEL(complex128);
#if GOOGLE_CUDA
REGISTER5(BinaryOp, GPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
complex64, complex128);
#endif // GOOGLE_CUDA
#elif TENSORFLOW_USE_ROCM
REGISTER3(BinaryOp, GPU, "Xlogy", functor::xlogy, float, Eigen::half, double);
#endif
} // namespace tensorflow

View File

@@ -147,9 +147,9 @@ BM_BINARY_SCALAR(sycl, Add);
#endif // TENSORFLOW_USE_SYCL
BM_BINARY_SCALAR(cpu, DivNoNan);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
BM_BINARY_SCALAR(gpu, DivNoNan);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
BM_BINARY_SCALAR(sycl, DivNoNan);
#endif // TENSORFLOW_USE_SYCL
@@ -204,11 +204,11 @@ Graph* CubeWithMulSquare(int num) {
BM_CUBE(cpu, CubeWithPow3);
BM_CUBE(cpu, CubeWithTwoMuls);
BM_CUBE(cpu, CubeWithMulSquare);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
BM_CUBE(gpu, CubeWithPow3);
BM_CUBE(gpu, CubeWithTwoMuls);
BM_CUBE(gpu, CubeWithMulSquare);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
BM_CUBE(sycl, CubeWithPow3);
BM_CUBE(sycl, CubeWithTwoMuls);