Adding ROCm support for the multinomial op
This commit is contained in:
parent
2f4121b1a7
commit
02afe6c1eb
@ -53,7 +53,7 @@ struct MultinomialFunctor {
|
|||||||
typename TTypes<OutputType>::Matrix output);
|
typename TTypes<OutputType>::Matrix output);
|
||||||
};
|
};
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
extern template struct MultinomialFunctor<GPUDevice, Eigen::half, int32>;
|
extern template struct MultinomialFunctor<GPUDevice, Eigen::half, int32>;
|
||||||
extern template struct MultinomialFunctor<GPUDevice, float, int32>;
|
extern template struct MultinomialFunctor<GPUDevice, float, int32>;
|
||||||
extern template struct MultinomialFunctor<GPUDevice, double, int32>;
|
extern template struct MultinomialFunctor<GPUDevice, double, int32>;
|
||||||
@ -65,7 +65,7 @@ extern template struct MultinomialFunctor<GPUDevice, float, int64>;
|
|||||||
extern template struct MultinomialFunctor<GPUDevice, double, int64>;
|
extern template struct MultinomialFunctor<GPUDevice, double, int64>;
|
||||||
extern template struct MultinomialFunctor<GPUDevice, int32, int64>;
|
extern template struct MultinomialFunctor<GPUDevice, int32, int64>;
|
||||||
extern template struct MultinomialFunctor<GPUDevice, int64, int64>;
|
extern template struct MultinomialFunctor<GPUDevice, int64, int64>;
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
template <typename T, typename OutputType>
|
template <typename T, typename OutputType>
|
||||||
struct MultinomialFunctor<CPUDevice, T, OutputType> {
|
struct MultinomialFunctor<CPUDevice, T, OutputType> {
|
||||||
@ -253,7 +253,7 @@ TF_CALL_float(REGISTER);
|
|||||||
TF_CALL_double(REGISTER);
|
TF_CALL_double(REGISTER);
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#define REGISTER(TYPE) \
|
#define REGISTER(TYPE) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("Multinomial") \
|
REGISTER_KERNEL_BUILDER(Name("Multinomial") \
|
||||||
.Device(DEVICE_GPU) \
|
.Device(DEVICE_GPU) \
|
||||||
@ -273,7 +273,7 @@ TF_CALL_float(REGISTER);
|
|||||||
TF_CALL_double(REGISTER);
|
TF_CALL_double(REGISTER);
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
template <typename Device, typename T, typename OutputType>
|
template <typename Device, typename T, typename OutputType>
|
||||||
class StatelessMultinomialOp : public MultinomialOp<Device, T, OutputType> {
|
class StatelessMultinomialOp : public MultinomialOp<Device, T, OutputType> {
|
||||||
@ -321,7 +321,7 @@ TF_CALL_float(REGISTER);
|
|||||||
TF_CALL_double(REGISTER);
|
TF_CALL_double(REGISTER);
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#define REGISTER(TYPE) \
|
#define REGISTER(TYPE) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
|
REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial") \
|
||||||
.Device(DEVICE_GPU) \
|
.Device(DEVICE_GPU) \
|
||||||
@ -343,7 +343,7 @@ TF_CALL_float(REGISTER);
|
|||||||
TF_CALL_double(REGISTER);
|
TF_CALL_double(REGISTER);
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
@ -29,6 +29,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
namespace gpuprim = ::cub;
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
namespace gpuprim = ::hipcub;
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace functor {
|
namespace functor {
|
||||||
@ -41,12 +47,12 @@ template <typename OutputType>
|
|||||||
__global__ void MultinomialKernel(int32 nthreads, const int32 num_classes,
|
__global__ void MultinomialKernel(int32 nthreads, const int32 num_classes,
|
||||||
const int32 num_samples, const float* scores,
|
const int32 num_samples, const float* scores,
|
||||||
const float* maxima, OutputType* output) {
|
const float* maxima, OutputType* output) {
|
||||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
GPU_1D_KERNEL_LOOP(index, nthreads) {
|
||||||
const int maxima_idx = index / num_classes;
|
const int maxima_idx = index / num_classes;
|
||||||
if (ldg(maxima + maxima_idx) == ldg(scores + index)) {
|
if (ldg(maxima + maxima_idx) == ldg(scores + index)) {
|
||||||
using UnsignedOutputType = typename std::make_unsigned<OutputType>::type;
|
using UnsignedOutputType = typename std::make_unsigned<OutputType>::type;
|
||||||
CudaAtomicMax(reinterpret_cast<UnsignedOutputType*>(output + maxima_idx),
|
GpuAtomicMax(reinterpret_cast<UnsignedOutputType*>(output + maxima_idx),
|
||||||
static_cast<UnsignedOutputType>(index % num_classes));
|
static_cast<UnsignedOutputType>(index % num_classes));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -98,8 +104,9 @@ struct MultinomialFunctor<GPUDevice, T, OutputType> {
|
|||||||
// Max-reduce along classes for each (batch, sample).
|
// Max-reduce along classes for each (batch, sample).
|
||||||
typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
|
typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
|
||||||
Constants<GPUDevice> constants;
|
Constants<GPUDevice> constants;
|
||||||
cub::Max op;
|
gpuprim::Max op;
|
||||||
functor::ReduceImpl<float, cub::Max, float*, const float*, ReductionAxes>(
|
functor::ReduceImpl<float, gpuprim::Max, float*, const float*,
|
||||||
|
ReductionAxes>(
|
||||||
/*ctx=*/ctx, /*out=*/maxima.data(), /*in=*/scores.data(), /*in_rank=*/2,
|
/*ctx=*/ctx, /*out=*/maxima.data(), /*in=*/scores.data(), /*in_rank=*/2,
|
||||||
/*in_dim0=*/batch_size * num_samples,
|
/*in_dim0=*/batch_size * num_samples,
|
||||||
/*in_dim1=*/num_classes, /*in_dim2=*/1, /*out_rank=*/1,
|
/*in_dim1=*/num_classes, /*in_dim2=*/1, /*out_rank=*/1,
|
||||||
@ -109,8 +116,8 @@ struct MultinomialFunctor<GPUDevice, T, OutputType> {
|
|||||||
output.device(d) = output.constant(0LL);
|
output.device(d) = output.constant(0LL);
|
||||||
|
|
||||||
const int32 work_items = batch_size * num_samples * num_classes;
|
const int32 work_items = batch_size * num_samples * num_classes;
|
||||||
GpuLaunchConfig config = GetCudaLaunchConfig(work_items, d);
|
GpuLaunchConfig config = GetGpuLaunchConfig(work_items, d);
|
||||||
TF_CHECK_OK(CudaLaunchKernel(
|
TF_CHECK_OK(GpuLaunchKernel(
|
||||||
MultinomialKernel<OutputType>, config.block_count,
|
MultinomialKernel<OutputType>, config.block_count,
|
||||||
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
||||||
num_classes, num_samples, scores.data(), maxima.data(), output.data()));
|
num_classes, num_samples, scores.data(), maxima.data(), output.data()));
|
||||||
@ -133,4 +140,4 @@ template struct MultinomialFunctor<GPUDevice, int64, int64>;
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
Loading…
x
Reference in New Issue
Block a user