Merge pull request from ROCmSoftwarePlatform:google_upstream_histogram_op

PiperOrigin-RevId: 255170957
TensorFlower Gardener 2019-06-26 11:06:56 -07:00
commit 107c9e2d1a
2 changed files with 20 additions and 10 deletions

@@ -129,7 +129,7 @@ class HistogramFixedWidthOp : public OpKernel {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth") \
.Device(DEVICE_GPU) \
@@ -142,6 +142,6 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // end namespace tensorflow
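
Note on the hunk above: it widens the existing CUDA-only preprocessor guard so that the same GPU kernel registration block also compiles in a ROCm build. A minimal, self-contained sketch of that guard pattern follows; it is not TensorFlow source, and the kHistogramDevice string is purely illustrative:

// Sketch of the GOOGLE_CUDA || TENSORFLOW_USE_ROCM guard pattern.
// Undefined macros evaluate to 0 in #if, so a plain host build takes
// the fallback branch.
#include <cstdio>

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// In the real kernel this block registers HistogramFixedWidth for DEVICE_GPU.
static const char* kHistogramDevice = "GPU (CUDA or ROCm build)";
#else
static const char* kHistogramDevice = "CPU-only build";
#endif

int main() {
  std::printf("HistogramFixedWidth would be registered for: %s\n",
              kHistogramDevice);
  return 0;
}

Compiled with -DGOOGLE_CUDA=1 or -DTENSORFLOW_USE_ROCM=1 the first branch is selected; with neither flag the CPU-only path remains.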


@@ -13,12 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#if GOOGLE_CUDA
#include "third_party/cub/device/device_histogram.cuh"
+#elif TENSORFLOW_USE_ROCM
+#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
+#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -28,6 +32,12 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
+#if GOOGLE_CUDA
+namespace gpuprim = ::cub;
+#elif TENSORFLOW_USE_ROCM
+namespace gpuprim = ::hipcub;
+#endif
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
@@ -66,11 +76,11 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
int num_levels = levels.size();
T* d_levels = levels.data();
int num_samples = values.size();
-const cudaStream_t& stream = GetCudaStream(context);
+const gpuStream_t& stream = GetGpuStream(context);
// The first HistogramRange is to obtain the temp storage size required
// with d_temp_storage = NULL passed to the call.
-auto err = cub::DeviceHistogram::HistogramRange(
+auto err = gpuprim::DeviceHistogram::HistogramRange(
/* d_temp_storage */ NULL,
/* temp_storage_bytes */ temp_storage_bytes,
/* d_samples */ d_samples,
@@ -79,10 +89,10 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
/* d_levels */ d_levels,
/* num_samples */ num_samples,
/* stream */ stream);
-if (err != cudaSuccess) {
+if (err != gpuSuccess) {
return errors::Internal(
"Could not launch HistogramRange to get temp storage: ",
-cudaGetErrorString(err), ".");
+GpuGetErrorString(err), ".");
}
Tensor temp_storage;
@@ -94,7 +104,7 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
// The second HistogramRange is to actual run with d_temp_storage
// allocated with temp_storage_bytes.
-err = cub::DeviceHistogram::HistogramRange(
+err = gpuprim::DeviceHistogram::HistogramRange(
/* d_temp_storage */ d_temp_storage,
/* temp_storage_bytes */ temp_storage_bytes,
/* d_samples */ d_samples,
@@ -103,9 +113,9 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
/* d_levels */ d_levels,
/* num_samples */ num_samples,
/* stream */ stream);
-if (err != cudaSuccess) {
+if (err != gpuSuccess) {
return errors::Internal(
"Could not launch HistogramRange: ", cudaGetErrorString(err), ".");
"Could not launch HistogramRange: ", GpuGetErrorString(err), ".");
}
return Status::OK();
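
Note on the second file: the remaining hunks replace CUDA-specific names (cub::, cudaStream_t, cudaSuccess, cudaGetErrorString) with the gpu-prefixed wrappers and route the library calls through a gpuprim namespace alias, so one call site serves both toolchains. A minimal, self-contained sketch of that namespace-alias idiom follows; the cuda_backend and rocm_backend namespaces are stand-ins for the real cub and hipcub libraries, not actual TensorFlow or library code:

// Sketch: bind the platform-specific library to one name, gpuprim,
// so the call sites stay identical across CUDA and ROCm builds.
#include <cstdio>

namespace cuda_backend {  // stand-in for ::cub
inline const char* Histogram() { return "cub::DeviceHistogram"; }
}  // namespace cuda_backend

namespace rocm_backend {  // stand-in for ::hipcub
inline const char* Histogram() { return "hipcub::DeviceHistogram"; }
}  // namespace rocm_backend

#if TENSORFLOW_USE_ROCM
namespace gpuprim = rocm_backend;
#else
namespace gpuprim = cuda_backend;  // default to the CUDA-style backend
#endif

int main() {
  // One call site, resolved at compile time to whichever backend was chosen.
  std::printf("HistogramRange would dispatch through %s\n",
              gpuprim::Histogram());
  return 0;
}

The design point is that the backend choice is made once, at the alias definition, so the two HistogramRange calls and their error checks in the diff above remain identical in the CUDA and ROCm builds.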