[ROCm] enable nextafter op on ROCm.

This commit is contained in:
Wen-Heng (Jack) Chung 2019-08-13 16:32:28 +00:00
parent 86d00f0125
commit 3188e95fe0
3 changed files with 17 additions and 4 deletions

View File

@ -30,8 +30,8 @@ REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER2(BinaryOp, GPU, "NextAfter", functor::nextafter, float, double);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -25,10 +25,23 @@ namespace functor {
template <typename T>
struct nextafter_op {
EIGEN_EMPTY_STRUCT_CTOR(nextafter_op)
// GPU kernels on ROCm may have issues including standard C++ APIs. Use
// specialized member functions and invoke HIP runtime APIs instead.
#if !TENSORFLOW_USE_ROCM
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x1,
const T& x2) const {
return std::nextafter(x1, x2);
}
#else
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const float operator()(
const float& x1, const float& x2) const {
return nextafterf(x1, x2);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const double operator()(
const double& x1, const double& x2) const {
return nextafter(x1, x2);
}
#endif
};
template <typename T>

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
#include "tensorflow/core/kernels/nextafter_op.h"
@ -26,4 +26,4 @@ DEFINE_BINARY2(nextafter, float, double);
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM