commit 67dd8f02fe
Merge pull request #40399 from kaixih:pr_cudnn_conv3d_ndhwc

PiperOrigin-RevId: 316954418
Change-Id: I797938817949be483961c560ac20161c42957377
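In brief: this change adds an NDHWC (channels-last) compute path to the three Conv3D GPU kernels (the forward LaunchConvOp, Conv3DBackpropInputOp, and Conv3DBackpropFilterOp). Each kernel now decides between NDHWC and NCDHW compute with the same rule, then picks the matching cuDNN tensor and filter layouts. The standalone sketch below models that rule; the enums and kCudnnVersion are stand-ins for TensorFlow's TensorFormat, the se::dnn layouts, and the real CUDNN_VERSION macro, so treat it as an illustration rather than the shipped code.

// Standalone model of the layout-selection rule this PR adds to each kernel.
// Format, DataLayout, FilterLayout, and kCudnnVersion are assumed stand-ins.
#include <iostream>
#include <tuple>

enum class Format { NHWC, NCHW };  // NDHWC / NCDHW in the 3-D kernels
enum class DataLayout { kBatchYXDepth, kBatchDepthYX };
enum class FilterLayout { kOutputYXInput, kOutputInputYX };

constexpr int kCudnnVersion = 8000;  // assumption: building against cuDNN 8

// Mirrors the diff: NDHWC compute only for fp16 on cuDNN >= 8, and only when
// the operator's data format is already NDHWC; everything else uses NCDHW.
Format ChooseComputeFormat(Format data_format, bool is_half) {
  const bool compute_in_nhwc = kCudnnVersion >= 8000 && is_half;
  return (compute_in_nhwc && data_format == Format::NHWC) ? Format::NHWC
                                                          : Format::NCHW;
}

int main() {
  const Format compute = ChooseComputeFormat(Format::NHWC, /*is_half=*/true);

  // The kernels then pick matching cuDNN descriptor layouts via std::tie.
  DataLayout data_layout;
  FilterLayout filter_layout;
  std::tie(data_layout, filter_layout) =
      compute == Format::NHWC
          ? std::make_tuple(DataLayout::kBatchYXDepth,
                            FilterLayout::kOutputYXInput)
          : std::make_tuple(DataLayout::kBatchDepthYX,
                            FilterLayout::kOutputInputYX);

  std::cout << (compute == Format::NHWC ? "NDHWC" : "NCDHW") << " compute, "
            << (filter_layout == FilterLayout::kOutputYXInput ? "OYXI" : "OIYX")
            << " filter layout\n";
  (void)data_layout;
}

The net effect: fp16 NDHWC models on cuDNN 8+ can run 3-D convolutions without the NDHWC-to-NCDHW-and-back round trip.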
tensorflow/core/kernels/conv_grad_ops_3d.cc

@@ -47,6 +47,7 @@ using stream_executor::dnn::DimIndex;
 #include "tensorflow/core/util/proto/proto_utils.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if GOOGLE_CUDA
+#include "third_party/gpus/cudnn/cudnn.h"
 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
@@ -1264,26 +1265,56 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
         << "Negative paddings: (" << padding_rows << ", " << padding_cols
         << ", " << padding_planes << ")";
+
+#if GOOGLE_CUDA
+    const bool compute_in_nhwc =
+        CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
+#else
+    // fast NDHWC implementation is a CUDA only feature
+    const bool compute_in_nhwc = false;
+#endif
+    const TensorFormat compute_data_format =
+        (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
+                                                         : FORMAT_NCHW;
+
+    VLOG(3) << "Compute Conv3DBackpropInput with cuDNN:"
+            << " data_format=" << ToString(data_format_)
+            << " compute_data_format=" << ToString(compute_data_format);
+
+    constexpr auto kComputeInNHWC =
+        std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
+                        se::dnn::FilterLayout::kOutputYXInput);
+    constexpr auto kComputeInNCHW =
+        std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
+                        se::dnn::FilterLayout::kOutputInputYX);
+
+    se::dnn::DataLayout compute_data_layout;
+    se::dnn::FilterLayout filter_layout;
+
+    std::tie(compute_data_layout, filter_layout) =
+        compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
+
     se::dnn::BatchDescriptor input_desc(3);
     input_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
         .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
         .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
         .set_feature_map_count(dims.in_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::BatchDescriptor output_desc(3);
     output_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X, dims.output_size(2))
         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
         .set_feature_map_count(dims.out_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::FilterDescriptor filter_desc(3);
     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
         .set_input_feature_map_count(filter_shape.dim_size(3))
-        .set_output_feature_map_count(filter_shape.dim_size(4));
+        .set_output_feature_map_count(filter_shape.dim_size(4))
+        .set_layout(filter_layout);
     se::dnn::ConvolutionDescriptor conv_desc(3);
     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
@@ -1298,21 +1329,28 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
+
     // Shape: out, in, z, y, x.
     Tensor transformed_filter;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(
-                     DataTypeToEnum<T>::value,
-                     TensorShape({filter_shape.dim_size(4),
-                                  filter_shape.dim_size(3), dims.filter_size(0),
-                                  dims.filter_size(1), dims.filter_size(2)}),
-                     &transformed_filter));
+    auto dst_format =
+        compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
+    TensorShape dst_shape =
+        dst_format == FORMAT_OIHW
+            ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
+                           dims.filter_size(0), dims.filter_size(1),
+                           dims.filter_size(2)})
+            : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
+                           dims.filter_size(1), dims.filter_size(2),
+                           filter_shape.dim_size(3)});
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
+                                          &transformed_filter));

     functor::TransformFilter<GPUDevice, T, int, 5>()(
-        context->eigen_device<GPUDevice>(), FORMAT_OIHW,
+        context->eigen_device<GPUDevice>(), dst_format,
         To32Bit(filter.tensor<T, 5>()),
         To32Bit(transformed_filter.tensor<T, 5>()));

     // Shape: batch, filters, z, y, x.
     Tensor transformed_out_backprop;
-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
                                 dims.output_size(0), dims.output_size(1),
                                 dims.output_size(2)};
@@ -1331,9 +1369,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
     }
     // Shape: batch, filters, z, y, x.
     Tensor pre_transformed_in_backprop;
-    OP_REQUIRES_OK(
-        context,
-        context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
-                               &pre_transformed_in_backprop));
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(
+                       DataTypeToEnum<T>::value,
+                       ShapeFromFormat(compute_data_format,
+                                       compatible_input_shape.dim_size(0),
+                                       {{compatible_input_shape.dim_size(2),
+                                         compatible_input_shape.dim_size(3),
+                                         compatible_input_shape.dim_size(4)}},
+                                       compatible_input_shape.dim_size(1)),
+                       &pre_transformed_in_backprop));

     auto out_backprop_ptr =
@@ -1355,7 +1399,7 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
         dims.batch_size,
         dims.in_depth,
         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
-        FORMAT_NCHW,
+        compute_data_format,
         dims.out_depth,
         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
@@ -1497,11 +1541,13 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {

     if (rows_odd || cols_odd || planes_odd) {
       Tensor in_backprop_remove_padding;
-      OP_REQUIRES_OK(context,
-                     context->allocate_temp(
-                         DataTypeToEnum<T>::value,
-                         {dims.batch_size, dims.in_depth, dims.input_size(0),
-                          dims.input_size(1), dims.input_size(2)},
-                         &in_backprop_remove_padding));
+      OP_REQUIRES_OK(
+          context, context->allocate_temp(
+                       DataTypeToEnum<T>::value,
+                       ShapeFromFormat(compute_data_format, dims.batch_size,
+                                       {{dims.input_size(0), dims.input_size(1),
+                                         dims.input_size(2)}},
+                                       dims.in_depth),
+                       &in_backprop_remove_padding));

       // Remove the padding for odd spatial dimensions.
@@ -1510,12 +1556,13 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
           To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                       .tensor<T, 5>()),
           {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
-          To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);
+          To32Bit(in_backprop_remove_padding.tensor<T, 5>()),
+          compute_data_format);

       pre_transformed_in_backprop = in_backprop_remove_padding;
     }

-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
       auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
       functor::NCHWToNHWC<GPUDevice, T, 5>()(
           context->eigen_device<GPUDevice>(),
@@ -1723,6 +1770,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
         << "Negative paddings: (" << padding_rows << ", " << padding_cols
         << ", " << padding_planes << ")";
+
+#if GOOGLE_CUDA
+    const bool compute_in_nhwc =
+        CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
+#else
+    // fast NDHWC implementation is a CUDA only feature
+    const bool compute_in_nhwc = false;
+#endif
+    const TensorFormat compute_data_format =
+        (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
+                                                         : FORMAT_NCHW;
+
+    VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:"
+            << " data_format=" << ToString(data_format_)
+            << " compute_data_format=" << ToString(compute_data_format);
+
+    constexpr auto kComputeInNHWC =
+        std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
+                        se::dnn::FilterLayout::kOutputYXInput);
+    constexpr auto kComputeInNCHW =
+        std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
+                        se::dnn::FilterLayout::kOutputInputYX);
+
+    se::dnn::DataLayout compute_data_layout;
+    se::dnn::FilterLayout filter_layout;
+
+    std::tie(compute_data_layout, filter_layout) =
+        compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
+
     se::dnn::BatchDescriptor input_desc(3);
     input_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X,
@@ -1732,20 +1808,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
         .set_spatial_dim(DimIndex::Z,
                          GetTensorDim(compatible_input, data_format_, '0'))
         .set_feature_map_count(dims.in_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::BatchDescriptor output_desc(3);
     output_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X, dims.output_size(2))
         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
         .set_feature_map_count(dims.out_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::FilterDescriptor filter_desc(3);
     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
         .set_input_feature_map_count(filter_shape.dim_size(3))
-        .set_output_feature_map_count(filter_shape.dim_size(4));
+        .set_output_feature_map_count(filter_shape.dim_size(4))
+        .set_layout(filter_layout);
     se::dnn::ConvolutionDescriptor conv_desc(3);
     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
@@ -1757,17 +1834,25 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
         .set_zero_padding(DimIndex::Y, padding_rows / 2)
         .set_zero_padding(DimIndex::Z, padding_planes / 2)
         .set_group_count(dims.in_depth / filter_shape.dim_size(3));
+
     Tensor pre_transformed_filter_backprop;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(
-                     DataTypeToEnum<T>::value,
-                     TensorShape({filter_shape.dim_size(4),
-                                  filter_shape.dim_size(3), dims.filter_size(0),
-                                  dims.filter_size(1), dims.filter_size(2)}),
-                     &pre_transformed_filter_backprop));
+    auto dst_format =
+        compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
+    TensorShape dst_shape =
+        dst_format == FORMAT_OIHW
+            ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
+                           dims.filter_size(0), dims.filter_size(1),
+                           dims.filter_size(2)})
+            : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
+                           dims.filter_size(1), dims.filter_size(2),
+                           filter_shape.dim_size(3)});
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
+                                          &pre_transformed_filter_backprop));

     Tensor transformed_out_backprop;
-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+      VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW.";
       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
                                 dims.output_size(0), dims.output_size(1),
                                 dims.output_size(2)};
@@ -1785,7 +1870,8 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
       transformed_out_backprop = out_backprop;
     }
     Tensor transformed_input;
-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+      VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW.";
       TensorShape nchw_shape = {
           dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
           compatible_input.dim_size(2), compatible_input.dim_size(3)};
@@ -1823,7 +1909,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
         dims.batch_size,
         dims.in_depth,
         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
-        FORMAT_NCHW,
+        compute_data_format,
         dims.out_depth,
         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
@@ -1947,7 +2033,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {

     auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
     functor::ReverseTransformFilter<GPUDevice, T, 5>()(
-        context->eigen_device<GPUDevice>(), /*src_filter_format=*/FORMAT_OIHW,
+        context->eigen_device<GPUDevice>(), /*src_filter_format=*/dst_format,
         toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
         filter_backprop->tensor<T, 5>());
   }
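The filter-handling pattern above repeats in the forward kernel below: TensorFlow stores a Conv3D filter as [z, y, x, in, out], and the kernels now transpose it to [out, in, z, y, x] for NCDHW compute or [out, z, y, x, in] for NDHWC compute (the diff's FORMAT_OIHW / FORMAT_OHWI dst_format). A small standalone model of the dst_shape computation follows, using plain arrays instead of TensorShape; the OIDHW/ODHWI labels are informal names for the 3-D variants, and the sizes are made up.

// Model of the dst_shape index permutation used by both files' filter
// transforms. Plain C++, no TensorFlow types; illustration only.
#include <array>
#include <cstdio>

// Filter dims as TensorFlow stores them for Conv3D: [z, y, x, in, out].
using Dims5 = std::array<long, 5>;

// OIHW-style -> [out, in, z, y, x]; OHWI-style -> [out, z, y, x, in].
Dims5 TransformedFilterShape(const Dims5& f, bool oihw) {
  return oihw ? Dims5{f[4], f[3], f[0], f[1], f[2]}
              : Dims5{f[4], f[0], f[1], f[2], f[3]};
}

int main() {
  const Dims5 filter = {3, 3, 3, 64, 128};  // z, y, x, in, out
  for (bool oihw : {true, false}) {
    const Dims5 d = TransformedFilterShape(filter, oihw);
    std::printf("%s: [%ld, %ld, %ld, %ld, %ld]\n", oihw ? "OIDHW" : "ODHWI",
                d[0], d[1], d[2], d[3], d[4]);
  }
}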
tensorflow/core/kernels/conv_ops_3d.cc

@@ -40,6 +40,7 @@ limitations under the License.
 using stream_executor::dnn::DimIndex;
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if GOOGLE_CUDA
+#include "third_party/gpus/cudnn/cudnn.h"
 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
@@ -201,7 +202,23 @@ struct LaunchConvOp<GPUDevice, T, OpKernelContext> {
       }
     }

-    if (data_format == FORMAT_NHWC) {
+#if GOOGLE_CUDA
+    const bool compute_in_nhwc =
+        CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
+#else
+    // fast NHWC implementation is a CUDA only feature
+    const bool compute_in_nhwc = false;
+#endif
+    const TensorFormat compute_data_format =
+        (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
+                                                        : FORMAT_NCHW;
+
+    VLOG(3) << "Compute Conv3D with cuDNN:"
+            << " data_format=" << ToString(data_format)
+            << " compute_data_format=" << ToString(compute_data_format);
+
+    if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+      VLOG(4) << "Convert the input tensor from NDHWC to NCDHW.";
       const TensorShape nchw_shape = ShapeFromFormat(
           FORMAT_NCHW, in_batch, {{in_planes, in_rows, in_cols}}, in_depth);
       if (in_depth > 1) {
@@ -219,8 +236,26 @@ struct LaunchConvOp<GPUDevice, T, OpKernelContext> {
       } else {
         CHECK(input.CopyFrom(input, nchw_shape));
       }
+    } else {
+      CHECK(data_format == compute_data_format)  // Crash OK
+          << "Illegal data and compute format pair:"
+          << " data_format=" << ToString(data_format)
+          << " compute_data_format=" << ToString(compute_data_format);
     }

+    constexpr auto kComputeInNHWC =
+        std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
+                        se::dnn::FilterLayout::kOutputYXInput);
+    constexpr auto kComputeInNCHW =
+        std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
+                        se::dnn::FilterLayout::kOutputInputYX);
+
+    se::dnn::DataLayout compute_data_layout;
+    se::dnn::FilterLayout filter_layout;
+
+    std::tie(compute_data_layout, filter_layout) =
+        compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
+
     CHECK(pad_rows >= 0 && pad_cols >= 0 && pad_planes >= 0)
         << "Negative paddings: (" << pad_rows << ", " << pad_cols << ", "
         << pad_planes << ")";
@@ -230,20 +265,21 @@ struct LaunchConvOp<GPUDevice, T, OpKernelContext> {
         .set_spatial_dim(DimIndex::X, in_cols)
         .set_spatial_dim(DimIndex::Y, in_rows)
         .set_spatial_dim(DimIndex::Z, in_planes)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::BatchDescriptor output_desc(3);
     output_desc.set_count(in_batch)
         .set_spatial_dim(DimIndex::X, out_cols)
         .set_spatial_dim(DimIndex::Y, out_rows)
         .set_spatial_dim(DimIndex::Z, out_planes)
         .set_feature_map_count(out_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::FilterDescriptor filter_desc(3);
     filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
         .set_spatial_dim(DimIndex::Y, filter_rows)
         .set_spatial_dim(DimIndex::Z, filter_planes)
         .set_input_feature_map_count(filter_depth)
-        .set_output_feature_map_count(out_depth);
+        .set_output_feature_map_count(out_depth)
+        .set_layout(filter_layout);
     se::dnn::ConvolutionDescriptor conv_desc(3);
     conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
         .set_dilation_rate(DimIndex::Y, dilations[1])
@@ -257,25 +293,41 @@ struct LaunchConvOp<GPUDevice, T, OpKernelContext> {
         .set_group_count(in_depth / filter_depth);

     Tensor transformed_filter;
-    OP_REQUIRES_OK(
-        ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                TensorShape({out_depth, in_depth, filter_planes,
-                                             filter_rows, filter_cols}),
-                                &transformed_filter));
+    auto dst_format =
+        compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
+    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
+            << " to " << ToString(dst_format);
+    TensorShape dst_shape =
+        dst_format == FORMAT_OIHW
+            ? TensorShape({filter.dim_size(4), filter.dim_size(3),
+                           filter.dim_size(0), filter.dim_size(1),
+                           filter.dim_size(2)})
+            : TensorShape({filter.dim_size(4), filter.dim_size(0),
+                           filter.dim_size(1), filter.dim_size(2),
+                           filter.dim_size(3)});
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
+                                           &transformed_filter));
     // filter: [x, y, z, in, out]
-    // t_filter: [out, in, x, y, z]
+    // t_filter: [out, in, x, y, z] (NCDHW) or
+    // t_filter: [out, x, y, z, in] (NDHWC)
     functor::TransformFilter<GPUDevice, T, int, 5>()(
-        ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+        ctx->eigen_device<GPUDevice>(), dst_format,
         To32Bit(filter.tensor<T, 5>()),
         To32Bit(transformed_filter.tensor<T, 5>()));

     Tensor transformed_output;
-    OP_REQUIRES_OK(
-        ctx, ctx->allocate_temp(
-                 DataTypeToEnum<T>::value,
-                 ShapeFromFormat(FORMAT_NCHW, in_batch,
-                                 {{out_planes, out_rows, out_cols}}, out_depth),
-                 &transformed_output));
+    if (data_format != compute_data_format) {
+      VLOG(4) << "Allocate temporary memory for output in compute data format";
+      OP_REQUIRES_OK(
+          ctx,
+          ctx->allocate_temp(
+              DataTypeToEnum<T>::value,
+              ShapeFromFormat(FORMAT_NCHW, in_batch,
                              {{out_planes, out_rows, out_cols}}, out_depth),
+              &transformed_output));
+    } else {
+      transformed_output = *output;
+    }

     auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                     input.template flat<T>().size());
@@ -295,7 +347,7 @@ struct LaunchConvOp<GPUDevice, T, OpKernelContext> {
         in_batch,
         in_depth,
         {{in_planes, in_rows, in_cols}},
-        FORMAT_NCHW,
+        compute_data_format,
         out_depth,
         {{filter_planes, filter_rows, filter_cols}},
         {{dilations[0], dilations[1], dilations[2]}},
@@ -455,15 +507,14 @@ struct LaunchConvOp<GPUDevice, T, OpKernelContext> {
                         ") filter shape(", filter.shape().DebugString(), ")"));
     }

-    if (data_format == FORMAT_NHWC) {
+    if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+      VLOG(4) << "Convert the output tensor back from NCDHW to NDHWC.";
       // t_output: [b, out, x, y, z]
       // output: [b, x, y, z, out]
       functor::NCHWToNHWC<GPUDevice, T, 5>()(
           ctx->eigen_device<GPUDevice>(),
           const_cast<const Tensor&>(transformed_output).tensor<T, 5>(),
           output->tensor<T, 5>());
-    } else {
-      *output = transformed_output;
     }
   }
 };
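A note on the forward kernel's tail: transformed_output now aliases *output whenever the compute format matches the data format, so the former unconditional *output = transformed_output copy is gone, and the NCDHW-to-NDHWC conversion runs only in the mixed-format case. The toy CPU loop below models the index mapping that functor::NCHWToNHWC performs on the GPU for the 5-D case; the shapes and values are made up for illustration.

// CPU model of a 5-D NCDHW -> NDHWC transpose: t_output is [b, out, z, y, x],
// output is [b, z, y, x, out], matching the diff's t_output/output comments.
#include <cstdio>
#include <vector>

int main() {
  const int b = 1, c = 2, z = 2, y = 2, x = 2;  // assumed tiny sizes
  std::vector<float> ncdhw(b * c * z * y * x), ndhwc(ncdhw.size());
  for (size_t i = 0; i < ncdhw.size(); ++i) ncdhw[i] = float(i);

  for (int n = 0; n < b; ++n)
    for (int k = 0; k < c; ++k)
      for (int d = 0; d < z; ++d)
        for (int h = 0; h < y; ++h)
          for (int w = 0; w < x; ++w) {
            const size_t src = (((n * c + k) * z + d) * y + h) * x + w;
            const size_t dst = (((n * z + d) * y + h) * x + w) * c + k;
            ndhwc[dst] = ncdhw[src];
          }

  // Element (n=0, k=1, d=0, h=0, w=0): NCDHW offset 8 maps to NDHWC offset 1.
  std::printf("%.0f %.0f\n", ncdhw[8], ndhwc[1]);
}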