Merge pull request #30130 from ROCmSoftwarePlatform:google_upstream_histogram_op
PiperOrigin-RevId: 255170957
commit 107c9e2d1a
tensorflow/core/kernels
@@ -129,7 +129,7 @@ class HistogramFixedWidthOp : public OpKernel {
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define REGISTER_KERNELS(type) \
   REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth") \
                               .Device(DEVICE_GPU) \
@@ -142,6 +142,6 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 } // end namespace tensorflow
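
Note on the hunk above: the only functional change in this file is widening the preprocessor guard so the existing GPU kernel registrations also build under ROCm. A minimal, self-contained C++ sketch of that guard pattern follows; RegisterGpuHistogramKernels and its printf bodies are illustrative stand-ins, not TensorFlow code.

#include <cstdio>

// Minimal sketch, not TensorFlow code: the same registration body is compiled
// when either GOOGLE_CUDA or TENSORFLOW_USE_ROCM is defined, so one source
// file serves both GPU backends.
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
static void RegisterGpuHistogramKernels() {
  // TensorFlow does this via REGISTER_KERNEL_BUILDER; here we only report
  // which backend the build targets.
#if defined(GOOGLE_CUDA)
  std::printf("registering HistogramFixedWidth GPU kernels (CUDA build)\n");
#else
  std::printf("registering HistogramFixedWidth GPU kernels (ROCm build)\n");
#endif
}
#endif

int main() {
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
  RegisterGpuHistogramKernels();
#else
  std::printf("CPU-only build: no GPU kernels registered\n");
#endif
  return 0;
}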
@@ -13,12 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#if GOOGLE_CUDA
 #include "third_party/cub/device/device_histogram.cuh"
+#elif TENSORFLOW_USE_ROCM
+#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
+#endif
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -28,6 +32,12 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"
 
+#if GOOGLE_CUDA
+namespace gpuprim = ::cub;
+#elif TENSORFLOW_USE_ROCM
+namespace gpuprim = ::hipcub;
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
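
Note on the hunk above: the gpuprim namespace alias lets the rest of the file spell every device-primitive call as gpuprim::..., resolving to CUB on CUDA builds and to hipCUB on ROCm builds. Below is a minimal, self-contained C++ sketch of the same aliasing technique; fake_cub and fake_hipcub are stand-ins for the real libraries.

#include <cstdio>

// Stand-ins for the real cub / hipcub namespaces, just to keep the sketch
// self-contained; the real code aliases ::cub or ::hipcub instead.
namespace fake_cub    { inline const char* Name() { return "cub"; } }
namespace fake_hipcub { inline const char* Name() { return "hipcub"; } }

// Pick one backend at compile time and give it a single, portable name.
#if defined(GOOGLE_CUDA)
namespace gpuprim = ::fake_cub;
#else  // pretend this is the TENSORFLOW_USE_ROCM branch
namespace gpuprim = ::fake_hipcub;
#endif

int main() {
  // All call sites use gpuprim::..., so only the #if block above knows
  // which backend library is actually behind the alias.
  std::printf("gpuprim resolves to: %s\n", gpuprim::Name());
  return 0;
}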
@@ -66,11 +76,11 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
     int num_levels = levels.size();
     T* d_levels = levels.data();
     int num_samples = values.size();
-    const cudaStream_t& stream = GetCudaStream(context);
+    const gpuStream_t& stream = GetGpuStream(context);
 
     // The first HistogramRange is to obtain the temp storage size required
     // with d_temp_storage = NULL passed to the call.
-    auto err = cub::DeviceHistogram::HistogramRange(
+    auto err = gpuprim::DeviceHistogram::HistogramRange(
         /* d_temp_storage */ NULL,
         /* temp_storage_bytes */ temp_storage_bytes,
         /* d_samples */ d_samples,
@@ -79,10 +89,10 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
         /* d_levels */ d_levels,
         /* num_samples */ num_samples,
         /* stream */ stream);
-    if (err != cudaSuccess) {
+    if (err != gpuSuccess) {
       return errors::Internal(
           "Could not launch HistogramRange to get temp storage: ",
-          cudaGetErrorString(err), ".");
+          GpuGetErrorString(err), ".");
     }
 
     Tensor temp_storage;
@@ -94,7 +104,7 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
 
     // The second HistogramRange is to actual run with d_temp_storage
     // allocated with temp_storage_bytes.
-    err = cub::DeviceHistogram::HistogramRange(
+    err = gpuprim::DeviceHistogram::HistogramRange(
         /* d_temp_storage */ d_temp_storage,
         /* temp_storage_bytes */ temp_storage_bytes,
         /* d_samples */ d_samples,
@@ -103,9 +113,9 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
         /* d_levels */ d_levels,
         /* num_samples */ num_samples,
         /* stream */ stream);
-    if (err != cudaSuccess) {
+    if (err != gpuSuccess) {
       return errors::Internal(
-          "Could not launch HistogramRange: ", cudaGetErrorString(err), ".");
+          "Could not launch HistogramRange: ", GpuGetErrorString(err), ".");
     }
 
     return Status::OK();
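
Note on the functor changes above: DeviceHistogram::HistogramRange follows the usual CUB/hipCUB two-pass convention seen in the diff comments, called first with a null d_temp_storage pointer to learn the required scratch size, then again with that scratch allocated to do the actual binning. The host-only C++ sketch below mimics that protocol with a hypothetical fake_prim::HistogramRange stand-in (no GPU, no stream argument), purely to illustrate the calling pattern.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in that mimics the CUB/hipCUB convention: when
// d_temp_storage is null, only report the required scratch size; otherwise
// bin `samples` into `counts` using the `levels` boundaries.
namespace fake_prim {
bool HistogramRange(void* d_temp_storage, size_t& temp_storage_bytes,
                    const float* samples, int* counts, int num_levels,
                    const float* levels, int num_samples) {
  if (d_temp_storage == nullptr) {
    temp_storage_bytes = 256;  // pretend scratch requirement
    return true;
  }
  for (int i = 0; i < num_samples; ++i) {
    for (int bin = 0; bin < num_levels - 1; ++bin) {
      if (samples[i] >= levels[bin] && samples[i] < levels[bin + 1]) {
        ++counts[bin];
        break;
      }
    }
  }
  return true;
}
}  // namespace fake_prim

int main() {
  std::vector<float> samples = {0.1f, 0.4f, 0.4f, 0.9f};
  std::vector<float> levels = {0.0f, 0.25f, 0.5f, 0.75f, 1.0f};
  std::vector<int> counts(levels.size() - 1, 0);

  // Pass 1: null scratch pointer -> just query temp_storage_bytes.
  size_t temp_storage_bytes = 0;
  fake_prim::HistogramRange(nullptr, temp_storage_bytes, samples.data(),
                            counts.data(), static_cast<int>(levels.size()),
                            levels.data(), static_cast<int>(samples.size()));

  // Pass 2: allocate the scratch buffer and do the real work.
  std::vector<unsigned char> temp_storage(temp_storage_bytes);
  fake_prim::HistogramRange(temp_storage.data(), temp_storage_bytes,
                            samples.data(), counts.data(),
                            static_cast<int>(levels.size()), levels.data(),
                            static_cast<int>(samples.size()));

  for (size_t i = 0; i < counts.size(); ++i)
    std::printf("bin %zu: %d\n", i, counts[i]);
  return 0;
}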