commit 107c9e2d1a
Merge pull request #30130 from ROCmSoftwarePlatform:google_upstream_histogram_op

PiperOrigin-RevId: 255170957
--- a/tensorflow/core/kernels/histogram_op.cc
+++ b/tensorflow/core/kernels/histogram_op.cc
@@ -129,7 +129,7 @@ class HistogramFixedWidthOp : public OpKernel {
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define REGISTER_KERNELS(type)                           \
   REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth")    \
                               .Device(DEVICE_GPU)        \
@@ -142,6 +142,6 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // end namespace tensorflow
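With the widened guard above, a single registration block now serves both GPU backends. A minimal sketch of the pattern, using a hypothetical ToyOp kernel standing in for the real op (REGISTER_KERNEL_BUILDER, DEVICE_GPU, and TF_CALL_GPU_NUMBER_TYPES are the real TensorFlow macros; everything ToyOp-related is illustrative only):

    // Sketch only: ToyOp is a hypothetical kernel, not part of this patch.
    #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    #define REGISTER_GPU_KERNELS(type)                                \
      REGISTER_KERNEL_BUILDER(                                        \
          Name("ToyOp").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
          ToyOp<GPUDevice, type>)
    TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
    #undef REGISTER_GPU_KERNELS
    #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM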
--- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -13,12 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#if GOOGLE_CUDA
 #include "third_party/cub/device/device_histogram.cuh"
+#elif TENSORFLOW_USE_ROCM
+#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
+#endif
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -28,6 +32,12 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"
 
+#if GOOGLE_CUDA
+namespace gpuprim = ::cub;
+#elif TENSORFLOW_USE_ROCM
+namespace gpuprim = ::hipcub;
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
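The gpuprim namespace alias added here lets every later call site be spelled once and resolve to CUB under CUDA or hipCUB under ROCm. A self-contained sketch of the idiom (the Sum reduction is illustrative, and the header paths are the upstream library ones rather than TensorFlow's vendored copies):

    // Alias one primitives library or the other behind a common name.
    #if GOOGLE_CUDA
    #include "cub/cub.cuh"
    namespace gpuprim = ::cub;
    #elif TENSORFLOW_USE_ROCM
    #include "hipcub/hipcub.hpp"
    namespace gpuprim = ::hipcub;
    #endif

    // One spelling at the call site, either backend underneath:
    //   gpuprim::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);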
@@ -66,11 +76,11 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
     int num_levels = levels.size();
     T* d_levels = levels.data();
     int num_samples = values.size();
-    const cudaStream_t& stream = GetCudaStream(context);
+    const gpuStream_t& stream = GetGpuStream(context);
 
     // The first HistogramRange is to obtain the temp storage size required
     // with d_temp_storage = NULL passed to the call.
-    auto err = cub::DeviceHistogram::HistogramRange(
+    auto err = gpuprim::DeviceHistogram::HistogramRange(
         /* d_temp_storage */ NULL,
         /* temp_storage_bytes */ temp_storage_bytes,
         /* d_samples */ d_samples,
@@ -79,10 +89,10 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
         /* d_levels */ d_levels,
         /* num_samples */ num_samples,
         /* stream */ stream);
-    if (err != cudaSuccess) {
+    if (err != gpuSuccess) {
       return errors::Internal(
           "Could not launch HistogramRange to get temp storage: ",
-          cudaGetErrorString(err), ".");
+          GpuGetErrorString(err), ".");
     }
 
     Tensor temp_storage;
@@ -94,7 +104,7 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
 
     // The second HistogramRange is to actual run with d_temp_storage
     // allocated with temp_storage_bytes.
-    err = cub::DeviceHistogram::HistogramRange(
+    err = gpuprim::DeviceHistogram::HistogramRange(
         /* d_temp_storage */ d_temp_storage,
         /* temp_storage_bytes */ temp_storage_bytes,
         /* d_samples */ d_samples,
@@ -103,9 +113,9 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
         /* d_levels */ d_levels,
         /* num_samples */ num_samples,
         /* stream */ stream);
-    if (err != cudaSuccess) {
+    if (err != gpuSuccess) {
       return errors::Internal(
-          "Could not launch HistogramRange: ", cudaGetErrorString(err), ".");
+          "Could not launch HistogramRange: ", GpuGetErrorString(err), ".");
     }
 
     return Status::OK();
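The call protocol in the hunks above is CUB/hipCUB's standard two-pass temp-storage idiom, which the patch leaves intact behind the gpuprim alias. A condensed sketch, assuming device pointers d_samples, d_levels, and d_histogram have already been prepared by the caller (allocation of d_temp_storage is elided):

    size_t temp_storage_bytes = 0;
    // Pass 1: null d_temp_storage, so the call only reports the bytes needed.
    auto err = gpuprim::DeviceHistogram::HistogramRange(
        /* d_temp_storage */ nullptr, temp_storage_bytes, d_samples,
        d_histogram, num_levels, d_levels, num_samples, stream);
    // ... allocate temp_storage_bytes of device memory as d_temp_storage ...
    // Pass 2: same arguments with real storage, so the histogram actually runs.
    err = gpuprim::DeviceHistogram::HistogramRange(
        d_temp_storage, temp_storage_bytes, d_samples,
        d_histogram, num_levels, d_levels, num_samples, stream);
    if (err != gpuSuccess) {
      // Surface GpuGetErrorString(err) via errors::Internal, as the patch does.
    }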