Adding ROCm support for the LRN op

Deven Desai 2019-07-10 01:11:24 +00:00
parent ec71eb1ac4
commit 3c2220abe7
2 changed files with 184 additions and 10 deletions
tensorflow/core/kernels


@@ -4366,7 +4366,7 @@ tf_kernel_library(
tf_kernel_library(
name = "lrn_op",
prefix = "lrn_op",
deps = NN_DEPS,
deps = NN_DEPS + if_rocm([":conv_ops_gpu_hdrs"]),
)
tf_kernel_library(


@@ -36,9 +36,17 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#if TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#endif
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
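With the widened guard, everything below that goes through StreamExecutor is compiled for both GPU back ends, while conv_ops_gpu.h (presumably where the DnnScratchAllocator used in the backward pass comes from) is pulled in only for ROCm. A minimal, self-contained sketch of this nesting pattern, using the same macros but none of the TensorFlow code:

#include <cstdio>

// Macros that are not defined evaluate to 0 in #if, so a plain CPU build takes
// the fallback branch; building with -DGOOGLE_CUDA=1 or -DTENSORFLOW_USE_ROCM=1
// selects the corresponding GPU path, with ROCm-only pieces nested inside.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if TENSORFLOW_USE_ROCM
static const char* kLrnBackend = "ROCm (MIOpen)";
#else
static const char* kLrnBackend = "CUDA (cuDNN)";
#endif
#else
static const char* kLrnBackend = "CPU only";
#endif

int main() {
  std::printf("LRN compiled for: %s\n", kLrnBackend);
  return 0;
}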
@@ -164,7 +172,7 @@ struct LaunchLRN<CPUDevice, T> {
T beta_;
};
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
struct LaunchLRN<GPUDevice, T> {
@@ -173,6 +181,7 @@ struct LaunchLRN<GPUDevice, T> {
void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
Tensor* output) {
#if GOOGLE_CUDA
OP_REQUIRES(
context, beta_ >= 0.01,
errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));
@@ -217,6 +226,71 @@ struct LaunchLRN<GPUDevice, T> {
.ok();
OP_REQUIRES(context, status,
errors::Internal("NormalizeWithDimensions launch failed"));
#elif TENSORFLOW_USE_ROCM
// For NHWC input/output tensors, convert to NCHW because it's the only
// supported format in MIOpen for now.
// Cast to platform-specific int to avoid conversion warnings.
const int batch = static_cast<int>(in.dim_size(0));
const int rows = static_cast<int>(in.dim_size(1));
const int cols = static_cast<int>(in.dim_size(2));
const int depth = static_cast<int>(in.dim_size(3));
Tensor transformed_input;
OP_REQUIRES_OK(context,
context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NCHW, in.shape(), FORMAT_NHWC),
&transformed_input));
functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
in.tensor<T, 4>(),
transformed_input.tensor<T, 4>());
Tensor transformed_output;
OP_REQUIRES_OK(
context, context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
&transformed_output));
perftools::gputools::dnn::BatchDescriptor dimensions_desc;
dimensions_desc.set_count(batch)
.set_height(rows)
.set_width(cols)
.set_feature_map_count(depth)
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
normalize_desc.set_bias(bias_)
.set_range(depth_radius_)
.set_alpha(alpha_)
.set_beta(beta_);
auto input_data =
AsDeviceMemory(transformed_input.template flat<T>().data(),
transformed_input.template flat<T>().size());
auto output_data =
AsDeviceMemory(transformed_output.template flat<T>().data(),
transformed_output.template flat<T>().size());
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
bool status =
stream
->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc,
input_data, &output_data)
.ok();
OP_REQUIRES(context, status,
errors::Internal("NormalizeWithDimensions launch failed"));
// Need to convert it back to NHWC once the MIOpen kernel finishes.
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
functor::NCHWToNHWC<GPUDevice, T, 4>()(
context->eigen_device<GPUDevice>(),
toConstTensor(transformed_output).template tensor<T, 4>(),
output->tensor<T, 4>());
#endif
}
int depth_radius_;
@@ -225,7 +299,7 @@ struct LaunchLRN<GPUDevice, T> {
T beta_;
};
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
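For orientation, a minimal CPU reference (an illustrative sketch, not code from this commit) of the normalization that the cuDNN and MIOpen launches above are expected to perform, following the tf.nn.lrn definition with the launcher's depth_radius_, bias_, alpha_ and beta_ passed as plain parameters:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Reference LRN over an NHWC float buffer of shape [batch, rows, cols, depth]:
//   out[b,r,c,d] = in[b,r,c,d] /
//       pow(bias + alpha * sum over k in [d-R, d+R] (clamped) of in[b,r,c,k]^2, beta)
std::vector<float> LrnReference(const std::vector<float>& in, int batch, int rows,
                                int cols, int depth, int depth_radius, float bias,
                                float alpha, float beta) {
  std::vector<float> out(in.size());
  for (int b = 0; b < batch; ++b)
    for (int r = 0; r < rows; ++r)
      for (int c = 0; c < cols; ++c) {
        const std::size_t base =
            ((static_cast<std::size_t>(b) * rows + r) * cols + c) * depth;
        for (int d = 0; d < depth; ++d) {
          float sqr_sum = 0.0f;
          for (int k = std::max(0, d - depth_radius);
               k <= std::min(depth - 1, d + depth_radius); ++k)
            sqr_sum += in[base + k] * in[base + k];
          out[base + d] = in[base + d] / std::pow(bias + alpha * sqr_sum, beta);
        }
      }
  return out;
}

int main() {
  // Single pixel with four channels.
  const std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  const std::vector<float> y =
      LrnReference(x, 1, 1, 1, 4, /*depth_radius=*/2, /*bias=*/1.f,
                   /*alpha=*/1.f, /*beta=*/0.5f);
  for (float v : y) std::printf("%f\n", v);
  return 0;
}

The NormalizeDescriptor fields set in the launcher map onto these parameters one-to-one (set_range carries the depth radius), so nothing beyond the layout round-trip has to be computed on the TensorFlow side.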
template <typename Device, typename T>
class LRNOp : public OpKernel {
@@ -292,7 +366,7 @@ TF_CALL_half(REGISTER_CPU);
#undef REGISTER_CPU
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) \
REGISTER_KERNEL_BUILDER( \
@@ -302,7 +376,7 @@ TF_CALL_float(REGISTER_GPU);
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(IS_MOBILE_PLATFORM)
@@ -390,7 +464,7 @@ struct LaunchLRNGrad<CPUDevice, T> {
T alpha_beta_2_;
};
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
struct LaunchLRNGrad<GPUDevice, T> {
@@ -400,6 +474,7 @@ struct LaunchLRNGrad<GPUDevice, T> {
void launch(OpKernelContext* context, OpKernel* kernel,
const Tensor& in_grads, const Tensor& in_image,
const Tensor& out_image, Tensor* output) {
#if GOOGLE_CUDA
OP_REQUIRES(
context, beta_ >= 0.01,
errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));
@@ -447,6 +522,105 @@ struct LaunchLRNGrad<GPUDevice, T> {
OP_REQUIRES(
context, status,
errors::Internal("NormalizeBackwardWithDimensions launch failed"));
#elif TENSORFLOW_USE_ROCM
// For NHWC input/output tensors, convert to NCHW because it's the only
// supported format in MIOpen for now.
const int64 batch = in_grads.dim_size(0);
const int64 rows = in_grads.dim_size(1);
const int64 cols = in_grads.dim_size(2);
const int64 depth = in_grads.dim_size(3);
Tensor transformed_in_grads;
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NCHW, in_grads.shape(),
FORMAT_NHWC),
&transformed_in_grads));
functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
in_grads.tensor<T, 4>(),
transformed_in_grads.tensor<T, 4>());
Tensor transformed_in_image;
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NCHW, in_image.shape(),
FORMAT_NHWC),
&transformed_in_image));
functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
in_image.tensor<T, 4>(),
transformed_in_image.tensor<T, 4>());
Tensor transformed_out_image;
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NCHW, out_image.shape(),
FORMAT_NHWC),
&transformed_out_image));
functor::NHWCToNCHW<GPUDevice, T, 4>()(
context->eigen_device<GPUDevice>(), out_image.tensor<T, 4>(),
transformed_out_image.tensor<T, 4>());
Tensor transformed_output;
OP_REQUIRES_OK(
context, context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
&transformed_output));
perftools::gputools::dnn::BatchDescriptor dimensions_desc;
dimensions_desc.set_count(batch)
.set_height(rows)
.set_width(cols)
.set_feature_map_count(depth)
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
normalize_desc.set_bias(bias_)
.set_range(depth_radius_)
.set_alpha(alpha_)
.set_beta(beta_);
auto input_grads_data =
AsDeviceMemory(transformed_in_grads.template flat<T>().data(),
transformed_in_grads.template flat<T>().size());
auto input_image_data =
AsDeviceMemory(transformed_in_image.template flat<T>().data(),
transformed_in_image.template flat<T>().size());
auto output_image_data =
AsDeviceMemory(transformed_out_image.template flat<T>().data(),
transformed_out_image.template flat<T>().size());
auto output_grads_data =
AsDeviceMemory(transformed_output.template flat<T>().data(),
transformed_output.template flat<T>().size());
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
static int64 NormalizeBackwardScratchSize = GetDnnWorkspaceLimit(
// default value is in bytes despite the name of the environment
// variable
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB
);
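// Assumed behavior of GetDnnWorkspaceLimit (shared with the CUDA conv path):
// when the environment variable is set, it is parsed as a megabyte count and
// converted to bytes; otherwise the default above (1LL << 32 bytes, i.e. 4 GB)
// caps the MIOpen scratch space handed to the allocator below.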
DnnScratchAllocator scratch_allocator(NormalizeBackwardScratchSize,
context);
bool status = stream
->ThenNormalizeBackwardWithDimensions(
normalize_desc, dimensions_desc, input_image_data,
output_image_data, input_grads_data,
&output_grads_data, &scratch_allocator)
.ok();
OP_REQUIRES(
context, status,
errors::Internal("NormalizeBackwardWithDimensions launch failed"));
// Need to convert it back to NHWC once the MIOpen kernel finishes.
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
functor::NCHWToNHWC<GPUDevice, T, 4>()(
context->eigen_device<GPUDevice>(),
toConstTensor(transformed_output).template tensor<T, 4>(),
output->tensor<T, 4>());
#endif
}
int depth_radius_;
@@ -455,7 +629,7 @@ struct LaunchLRNGrad<GPUDevice, T> {
T beta_;
};
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
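Both launchers bracket the MIOpen call with layout conversions because, as the comments above note, NCHW is currently the only layout MIOpen accepts. A self-contained sketch of the index permutation that the NHWCToNCHW / NCHWToNHWC functors amount to (hypothetical helper names, not the TensorFlow implementation):

#include <cstddef>
#include <vector>

// NHWC -> NCHW: the element at (n, h, w, c) moves from offset
//   ((n * H + h) * W + w) * C + c   to   ((n * C + c) * H + h) * W + w.
std::vector<float> NhwcToNchwSketch(const std::vector<float>& in, int N, int H,
                                    int W, int C) {
  std::vector<float> out(in.size());
  for (int n = 0; n < N; ++n)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int c = 0; c < C; ++c)
          out[((static_cast<std::size_t>(n) * C + c) * H + h) * W + w] =
              in[((static_cast<std::size_t>(n) * H + h) * W + w) * C + c];
  return out;
}

// NCHW -> NHWC is the inverse permutation, used once for the final output.
std::vector<float> NchwToNhwcSketch(const std::vector<float>& in, int N, int H,
                                    int W, int C) {
  std::vector<float> out(in.size());
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          out[((static_cast<std::size_t>(n) * H + h) * W + w) * C + c] =
              in[((static_cast<std::size_t>(n) * C + c) * H + h) * W + w];
  return out;
}

ShapeFromFormat(FORMAT_NCHW, shape, FORMAT_NHWC) in the launchers builds the correspondingly permuted shape for each temporary tensor, so the allocations match these layouts.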
template <typename Device, typename T>
class LRNGradOp : public OpKernel {
@@ -524,7 +698,7 @@ TF_CALL_half(REGISTER_CPU);
#undef REGISTER_CPU
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) \
REGISTER_KERNEL_BUILDER( \
@@ -534,7 +708,7 @@ TF_CALL_float(REGISTER_GPU);
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // !defined(IS_MOBILE_PLATFORM)