Adding ROCm support for the LRN op

commit 3c2220abe7
parent ec71eb1ac4
tensorflow/core/kernels/BUILD
@@ -4366,7 +4366,7 @@ tf_kernel_library(
 tf_kernel_library(
     name = "lrn_op",
     prefix = "lrn_op",
-    deps = NN_DEPS,
+    deps = NN_DEPS + if_rocm([":conv_ops_gpu_hdrs"]),
 )
 
 tf_kernel_library(
tensorflow/core/kernels/lrn_op.cc
@@ -36,9 +36,17 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
+#endif
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/conv_2d.h"
 #include "tensorflow/core/kernels/gpu_utils.h"
+#if TENSORFLOW_USE_ROCM
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#endif
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/util/stream_executor_util.h"
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace tensorflow {
 
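The pattern repeated throughout this commit is visible in the include block above: bare #if GOOGLE_CUDA guards become #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM so one kernel file serves both GPU backends, with CUDA-only and ROCm-only pieces nested inside the shared region. A minimal compilable sketch of that layout (the function names are illustrative, not from the commit):

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"  // CUDA-only header
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Declarations shared by both GPU backends.
void LaunchSharedGpuPath();

#if TENSORFLOW_USE_ROCM
// ROCm-only declarations (e.g. the MIOpen layout helpers pulled in via
// conv_ops_gpu.h above).
void LaunchRocmOnlyPath();
#endif
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM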
@@ -164,7 +172,7 @@ struct LaunchLRN<CPUDevice, T> {
   T beta_;
 };
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename T>
 struct LaunchLRN<GPUDevice, T> {
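For context on the parameters this launcher carries (depth_radius_, bias_, alpha_, beta_): TensorFlow defines LRN over the innermost (depth) dimension, as documented for tf.nn.local_response_normalization,

$$\mathrm{sqr\_sum}_{a,b,c,d} = \sum_{i=d-r}^{d+r} x_{a,b,c,i}^{2}, \qquad y = \frac{x}{(\mathrm{bias} + \alpha \cdot \mathrm{sqr\_sum})^{\beta}}$$

where r = depth_radius_ is the half-width of the normalization window and the sum is clamped to valid depth indices. This r is what the ROCm path below feeds to NormalizeDescriptor::set_range.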
@@ -173,6 +181,7 @@ struct LaunchLRN<GPUDevice, T> {
 
   void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
               Tensor* output) {
+#if GOOGLE_CUDA
     OP_REQUIRES(
         context, beta_ >= 0.01,
         errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));
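This hunk makes the beta_ >= 0.01 check cuDNN-only; with the #elif added further down, launch() becomes a two-way preprocessor branch. A condensed skeleton of the resulting control flow (a sketch of the structure, not the verbatim source):

void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
            Tensor* output) {
#if GOOGLE_CUDA
  // cuDNN path: works on the NHWC tensor directly, but rejects small beta.
  OP_REQUIRES(
      context, beta_ >= 0.01,
      errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));
  // ... build descriptors, call stream->ThenNormalizeWithDimensions(...) ...
#elif TENSORFLOW_USE_ROCM
  // MIOpen path: transpose NHWC -> NCHW, normalize, transpose back
  // (spelled out in the next hunk).
#endif
}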
@@ -217,6 +226,71 @@ struct LaunchLRN<GPUDevice, T> {
             .ok();
     OP_REQUIRES(context, status,
                 errors::Internal("NormalizeWithDimensions launch failed"));
+#elif TENSORFLOW_USE_ROCM
+    // For NHWC input/output tensors, convert to NCHW because it's the only
+    // supported format in MIOpen for now.
+
+    // Cast to platform-specific int to avoid conversion warnings.
+    const int batch = static_cast<int>(in.dim_size(0));
+    const int rows = static_cast<int>(in.dim_size(1));
+    const int cols = static_cast<int>(in.dim_size(2));
+    const int depth = static_cast<int>(in.dim_size(3));
+
+    Tensor transformed_input;
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(
+                       DataTypeToEnum<T>::value,
+                       ShapeFromFormat(FORMAT_NCHW, in.shape(), FORMAT_NHWC),
+                       &transformed_input));
+    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
+                                           in.tensor<T, 4>(),
+                                           transformed_input.tensor<T, 4>());
+
+    Tensor transformed_output;
+    OP_REQUIRES_OK(
+        context, context->allocate_temp(
+                     DataTypeToEnum<T>::value,
+                     ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
+                     &transformed_output));
+
+    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
+    dimensions_desc.set_count(batch)
+        .set_height(rows)
+        .set_width(cols)
+        .set_feature_map_count(depth)
+        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+
+    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
+    normalize_desc.set_bias(bias_)
+        .set_range(depth_radius_)
+        .set_alpha(alpha_)
+        .set_beta(beta_);
+
+    auto input_data =
+        AsDeviceMemory(transformed_input.template flat<T>().data(),
+                       transformed_input.template flat<T>().size());
+    auto output_data =
+        AsDeviceMemory(transformed_output.template flat<T>().data(),
+                       transformed_output.template flat<T>().size());
+
+    auto* stream = context->op_device_context()->stream();
+    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+    bool status =
+        stream
+            ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc,
+                                          input_data, &output_data)
+            .ok();
+    OP_REQUIRES(context, status,
+                errors::Internal("NormalizeWithDimensions launch failed"));
+
+    // Need to convert it back to NHWC once the MIOpen kernel finishes.
+    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+    functor::NCHWToNHWC<GPUDevice, T, 4>()(
+        context->eigen_device<GPUDevice>(),
+        toConstTensor(transformed_output).template tensor<T, 4>(),
+        output->tensor<T, 4>());
+#endif
   }
 
   int depth_radius_;
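The ROCm branch above allocates its temporaries with ShapeFromFormat(FORMAT_NCHW, shape, FORMAT_NHWC), i.e. it reorders an NHWC shape [batch, rows, cols, depth] into [batch, depth, rows, cols] before handing buffers to MIOpen, and NHWCToNCHW/NCHWToNHWC perform the matching data transposes on the device. A standalone illustration of just the shape permutation (hypothetical helper, plain C++ with no TensorFlow types):

#include <array>
#include <cassert>
#include <cstdio>

// Hypothetical stand-in for ShapeFromFormat(FORMAT_NCHW, shape, FORMAT_NHWC):
// reorder NHWC dimension sizes into NCHW order.
std::array<long, 4> NhwcToNchwShape(const std::array<long, 4>& nhwc) {
  return {nhwc[0], nhwc[3], nhwc[1], nhwc[2]};  // N, C, H, W
}

int main() {
  std::array<long, 4> nhwc = {32, 14, 14, 192};  // batch, rows, cols, depth
  auto nchw = NhwcToNchwShape(nhwc);
  assert(nchw[1] == 192 && nchw[2] == 14);
  std::printf("NCHW: %ld %ld %ld %ld\n", nchw[0], nchw[1], nchw[2], nchw[3]);
}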
@@ -225,7 +299,7 @@ struct LaunchLRN<GPUDevice, T> {
   T beta_;
 };
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename Device, typename T>
 class LRNOp : public OpKernel {
@@ -292,7 +366,7 @@ TF_CALL_half(REGISTER_CPU);
 
 #undef REGISTER_CPU
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU(T) \
   REGISTER_KERNEL_BUILDER( \
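The hunk context cuts REGISTER_GPU off mid-definition, so its body is elided here. Assuming the standard TensorFlow registration pattern (the next hunk's context line, TF_CALL_float(REGISTER_GPU);, shows float is the type being registered), the completed macro would look something like this sketch, not a verbatim quote of the file:

#define REGISTER_GPU(T)                                      \
  REGISTER_KERNEL_BUILDER(                                   \
      Name("LRN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LRNOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU);
#undef REGISTER_GPU

Because both the CUDA and ROCm code paths live behind the same LRNOp<GPUDevice, T>, widening the surrounding #if is all that is needed for this registration to apply to ROCm builds.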
@@ -302,7 +376,7 @@ TF_CALL_float(REGISTER_GPU);
 
 #undef REGISTER_GPU
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #if !defined(IS_MOBILE_PLATFORM)
 
@@ -390,7 +464,7 @@ struct LaunchLRNGrad<CPUDevice, T> {
   T alpha_beta_2_;
 };
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename T>
 struct LaunchLRNGrad<GPUDevice, T> {
@@ -400,6 +474,7 @@ struct LaunchLRNGrad<GPUDevice, T> {
   void launch(OpKernelContext* context, OpKernel* kernel,
               const Tensor& in_grads, const Tensor& in_image,
               const Tensor& out_image, Tensor* output) {
+#if GOOGLE_CUDA
     OP_REQUIRES(
         context, beta_ >= 0.01,
         errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));
@@ -447,6 +522,105 @@ struct LaunchLRNGrad<GPUDevice, T> {
     OP_REQUIRES(
         context, status,
         errors::Internal("NormalizeBackwardWithDimensions launch failed"));
+#elif TENSORFLOW_USE_ROCM
+    // For NHWC input/output tensors, convert to NCHW because it's the only
+    // supported format in MIOpen for now.
+    const int64 batch = in_grads.dim_size(0);
+    const int64 rows = in_grads.dim_size(1);
+    const int64 cols = in_grads.dim_size(2);
+    const int64 depth = in_grads.dim_size(3);
+
+    Tensor transformed_in_grads;
+    OP_REQUIRES_OK(context, context->allocate_temp(
+                                DataTypeToEnum<T>::value,
+                                ShapeFromFormat(FORMAT_NCHW, in_grads.shape(),
+                                                FORMAT_NHWC),
+                                &transformed_in_grads));
+    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
+                                           in_grads.tensor<T, 4>(),
+                                           transformed_in_grads.tensor<T, 4>());
+
+    Tensor transformed_in_image;
+    OP_REQUIRES_OK(context, context->allocate_temp(
+                                DataTypeToEnum<T>::value,
+                                ShapeFromFormat(FORMAT_NCHW, in_image.shape(),
+                                                FORMAT_NHWC),
+                                &transformed_in_image));
+    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
+                                           in_image.tensor<T, 4>(),
+                                           transformed_in_image.tensor<T, 4>());
+
+    Tensor transformed_out_image;
+    OP_REQUIRES_OK(context, context->allocate_temp(
+                                DataTypeToEnum<T>::value,
+                                ShapeFromFormat(FORMAT_NCHW, out_image.shape(),
+                                                FORMAT_NHWC),
+                                &transformed_out_image));
+    functor::NHWCToNCHW<GPUDevice, T, 4>()(
+        context->eigen_device<GPUDevice>(), out_image.tensor<T, 4>(),
+        transformed_out_image.tensor<T, 4>());
+
+    Tensor transformed_output;
+    OP_REQUIRES_OK(
+        context, context->allocate_temp(
+                     DataTypeToEnum<T>::value,
+                     ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
+                     &transformed_output));
+
+    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
+    dimensions_desc.set_count(batch)
+        .set_height(rows)
+        .set_width(cols)
+        .set_feature_map_count(depth)
+        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+
+    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
+    normalize_desc.set_bias(bias_)
+        .set_range(depth_radius_)
+        .set_alpha(alpha_)
+        .set_beta(beta_);
+
+    auto input_grads_data =
+        AsDeviceMemory(transformed_in_grads.template flat<T>().data(),
+                       transformed_in_grads.template flat<T>().size());
+    auto input_image_data =
+        AsDeviceMemory(transformed_in_image.template flat<T>().data(),
+                       transformed_in_image.template flat<T>().size());
+    auto output_image_data =
+        AsDeviceMemory(transformed_out_image.template flat<T>().data(),
+                       transformed_out_image.template flat<T>().size());
+    auto output_grads_data =
+        AsDeviceMemory(transformed_output.template flat<T>().data(),
+                       transformed_output.template flat<T>().size());
+
+    auto* stream = context->op_device_context()->stream();
+    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+    static int64 NormalizeBackwardScratchSize = GetDnnWorkspaceLimit(
+        // default value is in bytes despite the name of the environment
+        // variable
+        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
+    );
+
+    DnnScratchAllocator scratch_allocator(NormalizeBackwardScratchSize,
+                                          context);
+    bool status = stream
+                      ->ThenNormalizeBackwardWithDimensions(
+                          normalize_desc, dimensions_desc, input_image_data,
+                          output_image_data, input_grads_data,
+                          &output_grads_data, &scratch_allocator)
+                      .ok();
+    OP_REQUIRES(
+        context, status,
+        errors::Internal("NormalizeBackwardWithDimensions launch failed"));
+
+    // Need to convert it back to NHWC once the MIOpen kernel finishes.
+    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+    functor::NCHWToNHWC<GPUDevice, T, 4>()(
+        context->eigen_device<GPUDevice>(),
+        toConstTensor(transformed_output).template tensor<T, 4>(),
+        output->tensor<T, 4>());
+#endif
   }
 
   int depth_radius_;
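The backward pass is the one place this commit needs scratch memory: MIOpen's backward normalization takes a workspace, sized via GetDnnWorkspaceLimit with a 4 GB fallback (1LL << 32 bytes). Per the in-diff comment, the TF_CUDNN_WORKSPACE_LIMIT_IN_MB environment variable is read in megabytes while the fallback is in bytes. A minimal model of that lookup, an assumption about the helper's behavior consistent with the comment rather than its verbatim implementation:

#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for GetDnnWorkspaceLimit: the env var value is in
// MB, the fallback default is already in bytes.
int64_t DnnWorkspaceLimitBytes(const char* envvar_in_mb,
                               int64_t default_value_in_bytes) {
  if (const char* val = std::getenv(envvar_in_mb)) {
    return std::strtoll(val, nullptr, 10) * (1LL << 20);  // MB -> bytes
  }
  return default_value_in_bytes;
}

int main() {
  // Mirrors the call in the diff: 1LL << 32 bytes == 4 GB default.
  const int64_t limit =
      DnnWorkspaceLimitBytes("TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);
  std::printf("workspace limit: %lld bytes\n", static_cast<long long>(limit));
}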
@@ -455,7 +629,7 @@ struct LaunchLRNGrad<GPUDevice, T> {
   T beta_;
 };
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename Device, typename T>
 class LRNGradOp : public OpKernel {
@@ -524,7 +698,7 @@ TF_CALL_half(REGISTER_CPU);
 
 #undef REGISTER_CPU
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU(T) \
   REGISTER_KERNEL_BUILDER( \
@@ -534,7 +708,7 @@ TF_CALL_float(REGISTER_GPU);
 
 #undef REGISTER_GPU
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #endif  // !defined(IS_MOBILE_PLATFORM)