[ROCm] enable nextafter op on ROCm.
This commit is contained in:
parent
86d00f0125
commit
3188e95fe0
@@ -30,8 +30,8 @@ REGISTER_SYCL_KERNEL(double);
|
|||||||
#undef REGISTER_SYCL_KERNEL
|
#undef REGISTER_SYCL_KERNEL
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
REGISTER2(BinaryOp, GPU, "NextAfter", functor::nextafter, float, double);
|
REGISTER2(BinaryOp, GPU, "NextAfter", functor::nextafter, float, double);
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@@ -25,10 +25,23 @@ namespace functor {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
struct nextafter_op {
|
struct nextafter_op {
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(nextafter_op)
|
EIGEN_EMPTY_STRUCT_CTOR(nextafter_op)
|
||||||
|
// GPU kernels on ROCm may have issues including standard C++ APIs. Use
|
||||||
|
// specialized member functions and invoke HIP runtime APIs instead.
|
||||||
|
#if !TENSORFLOW_USE_ROCM
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x1,
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x1,
|
||||||
const T& x2) const {
|
const T& x2) const {
|
||||||
return std::nextafter(x1, x2);
|
return std::nextafter(x1, x2);
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const float operator()(
|
||||||
|
const float& x1, const float& x2) const {
|
||||||
|
return nextafterf(x1, x2);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const double operator()(
|
||||||
|
const double& x1, const double& x2) const {
|
||||||
|
return nextafter(x1, x2);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
|
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
|
||||||
#include "tensorflow/core/kernels/nextafter_op.h"
|
#include "tensorflow/core/kernels/nextafter_op.h"
|
||||||
@@ -26,4 +26,4 @@ DEFINE_BINARY2(nextafter, float, double);
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
Loading…
Reference in New Issue
Block a user