From 77e46ebf9745d798863b5c7ed26d6bf077700008 Mon Sep 17 00:00:00 2001
From: Kaixi Hou
Date: Mon, 30 Mar 2020 17:25:38 -0700
Subject: [PATCH] conv3d ndhwc plumbing

---
 tensorflow/core/kernels/conv_grad_ops_3d.cc | 148 ++++++++++++++++----
 tensorflow/core/kernels/conv_ops_3d.cc      |  88 +++++++++---
 2 files changed, 193 insertions(+), 43 deletions(-)

diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 2183d0e0885..fc2d58ec94f 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -50,6 +50,7 @@ using stream_executor::dnn::DimIndex;
 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
+#include "third_party/gpus/cudnn/cudnn.h"
 #endif  // GOOGLE_CUDA

 namespace {
@@ -1264,26 +1265,56 @@ class Conv3DBackpropInputOp : public OpKernel {
     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
         << "Negative paddings: (" << padding_rows << ", " << padding_cols
         << ", " << padding_planes << ")";
+
+#if GOOGLE_CUDA
+    const bool compute_in_nhwc = CUDNN_VERSION >= 8000 &&
+                                 DataTypeToEnum<T>::value == DT_HALF;
+#else
+    // fast NDHWC implementation is a CUDA-only feature
+    const bool compute_in_nhwc = false;
+#endif
+    const TensorFormat compute_data_format =
+        (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
+                                                         : FORMAT_NCHW;
+
+    VLOG(3) << "Compute Conv3DBackpropInput with cuDNN:"
+            << " data_format=" << ToString(data_format_)
+            << " compute_data_format=" << ToString(compute_data_format);
+
+    constexpr auto kComputeInNHWC =
+        std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
+                        se::dnn::FilterLayout::kOutputYXInput);
+    constexpr auto kComputeInNCHW =
+        std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
+                        se::dnn::FilterLayout::kOutputInputYX);
+
+    se::dnn::DataLayout compute_data_layout;
+    se::dnn::FilterLayout filter_layout;
+
+    std::tie(compute_data_layout, filter_layout) =
+        compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
+
     se::dnn::BatchDescriptor input_desc(3);
     input_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
         .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
         .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
         .set_feature_map_count(dims.in_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::BatchDescriptor output_desc(3);
     output_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X, dims.output_size(2))
         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
         .set_feature_map_count(dims.out_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::FilterDescriptor filter_desc(3);
     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
         .set_input_feature_map_count(filter_shape.dim_size(3))
-        .set_output_feature_map_count(filter_shape.dim_size(4));
+        .set_output_feature_map_count(filter_shape.dim_size(4))
+        .set_layout(filter_layout);
     se::dnn::ConvolutionDescriptor conv_desc(3);
     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
@@ -1298,21 +1329,33 @@ class Conv3DBackpropInputOp : public OpKernel {
     // Shape: out, in, z, y, x.
     Tensor transformed_filter;
+    auto dst_format =
+        compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
+    TensorShape dst_shape =
+        dst_format == FORMAT_OIHW
+            ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
+                           dims.filter_size(0),
+                           dims.filter_size(1),
+                           dims.filter_size(2)})
+            : TensorShape({filter_shape.dim_size(4),
+                           dims.filter_size(0),
+                           dims.filter_size(1),
+                           dims.filter_size(2),
+                           filter_shape.dim_size(3)});
     OP_REQUIRES_OK(
         context,
         context->allocate_temp(
             DataTypeToEnum<T>::value,
-            TensorShape({filter_shape.dim_size(4),
-                         filter_shape.dim_size(3), dims.filter_size(0),
-                         dims.filter_size(1), dims.filter_size(2)}),
+            dst_shape,
             &transformed_filter));
+
     functor::TransformFilter<GPUDevice, T, int, 5>()(
-        context->eigen_device<GPUDevice>(), FORMAT_OIHW,
+        context->eigen_device<GPUDevice>(), dst_format,
         To32Bit(filter.tensor<T, 5>()),
         To32Bit(transformed_filter.tensor<T, 5>()));

     // Shape: batch, filters, z, y, x.
     Tensor transformed_out_backprop;
-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
                                 dims.output_size(0), dims.output_size(1),
                                 dims.output_size(2)};
@@ -1333,8 +1376,15 @@ class Conv3DBackpropInputOp : public OpKernel {
     Tensor pre_transformed_in_backprop;
     OP_REQUIRES_OK(
         context,
-        context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
-                               &pre_transformed_in_backprop));
+        context->allocate_temp(
+            DataTypeToEnum<T>::value,
+            ShapeFromFormat(
+                compute_data_format, compatible_input_shape.dim_size(0),
+                {{compatible_input_shape.dim_size(2),
+                  compatible_input_shape.dim_size(3),
+                  compatible_input_shape.dim_size(4)}},
+                compatible_input_shape.dim_size(1)),
+            &pre_transformed_in_backprop));

     auto out_backprop_ptr =
         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
@@ -1355,7 +1405,7 @@ class Conv3DBackpropInputOp : public OpKernel {
         dims.batch_size,
         dims.in_depth,
         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
-        FORMAT_NCHW,
+        compute_data_format,
         dims.out_depth,
         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
@@ -1500,8 +1550,11 @@ class Conv3DBackpropInputOp : public OpKernel {
       OP_REQUIRES_OK(context,
                      context->allocate_temp(
                          DataTypeToEnum<T>::value,
-                         {dims.batch_size, dims.in_depth, dims.input_size(0),
-                          dims.input_size(1), dims.input_size(2)},
+                         ShapeFromFormat(
+                             compute_data_format, dims.batch_size,
+                             {{dims.input_size(0), dims.input_size(1),
+                               dims.input_size(2)}},
+                             dims.in_depth),
                          &in_backprop_remove_padding));

       // Remove the padding for odd spatial dimensions.
@@ -1510,12 +1563,13 @@ class Conv3DBackpropInputOp : public OpKernel {
           To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                       .tensor<T, 5>()),
           {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
-          To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);
+          To32Bit(in_backprop_remove_padding.tensor<T, 5>()),
+          compute_data_format);

       pre_transformed_in_backprop = in_backprop_remove_padding;
     }

-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
       auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
       functor::NCHWToNHWC<GPUDevice, T, 5>()(
           context->eigen_device<GPUDevice>(),
@@ -1723,6 +1777,35 @@ class Conv3DBackpropFilterOp : public OpKernel {
     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
         << "Negative paddings: (" << padding_rows << ", " << padding_cols
         << ", " << padding_planes << ")";
+
+#if GOOGLE_CUDA
+    const bool compute_in_nhwc = CUDNN_VERSION >= 8000 &&
+                                 DataTypeToEnum<T>::value == DT_HALF;
+#else
+    // fast NDHWC implementation is a CUDA-only feature
+    const bool compute_in_nhwc = false;
+#endif
+    const TensorFormat compute_data_format =
+        (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
+                                                         : FORMAT_NCHW;
+
+    VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:"
+            << " data_format=" << ToString(data_format_)
+            << " compute_data_format=" << ToString(compute_data_format);
+
+    constexpr auto kComputeInNHWC =
+        std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
+                        se::dnn::FilterLayout::kOutputYXInput);
+    constexpr auto kComputeInNCHW =
+        std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
+                        se::dnn::FilterLayout::kOutputInputYX);
+
+    se::dnn::DataLayout compute_data_layout;
+    se::dnn::FilterLayout filter_layout;
+
+    std::tie(compute_data_layout, filter_layout) =
+        compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
+
     se::dnn::BatchDescriptor input_desc(3);
     input_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X,
@@ -1732,20 +1815,21 @@ class Conv3DBackpropFilterOp : public OpKernel {
         .set_spatial_dim(DimIndex::Z,
                          GetTensorDim(compatible_input, data_format_, '0'))
         .set_feature_map_count(dims.in_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::BatchDescriptor output_desc(3);
     output_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X, dims.output_size(2))
         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
         .set_feature_map_count(dims.out_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::FilterDescriptor filter_desc(3);
     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
         .set_input_feature_map_count(filter_shape.dim_size(3))
-        .set_output_feature_map_count(filter_shape.dim_size(4));
+        .set_output_feature_map_count(filter_shape.dim_size(4))
+        .set_layout(filter_layout);
     se::dnn::ConvolutionDescriptor conv_desc(3);
     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
@@ -1757,17 +1841,30 @@ class Conv3DBackpropFilterOp : public OpKernel {
         .set_zero_padding(DimIndex::Y, padding_rows / 2)
         .set_zero_padding(DimIndex::Z, padding_planes / 2)
         .set_group_count(dims.in_depth / filter_shape.dim_size(3));
+
     Tensor pre_transformed_filter_backprop;
+    auto dst_format =
+        compute_data_format == FORMAT_NCHW ? FORMAT_OIHW
+                                           : FORMAT_OHWI;
+    TensorShape dst_shape =
+        dst_format == FORMAT_OIHW
+            ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
+                           dims.filter_size(0),
+                           dims.filter_size(1),
+                           dims.filter_size(2)})
+            : TensorShape({filter_shape.dim_size(4),
+                           dims.filter_size(0),
+                           dims.filter_size(1),
+                           dims.filter_size(2),
+                           filter_shape.dim_size(3)});
     OP_REQUIRES_OK(
         context,
         context->allocate_temp(
             DataTypeToEnum<T>::value,
-            TensorShape({filter_shape.dim_size(4),
-                         filter_shape.dim_size(3), dims.filter_size(0),
-                         dims.filter_size(1), dims.filter_size(2)}),
+            dst_shape,
             &pre_transformed_filter_backprop));

     Tensor transformed_out_backprop;
-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+      VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW.";
       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
                                 dims.output_size(0), dims.output_size(1),
                                 dims.output_size(2)};
@@ -1785,7 +1882,8 @@ class Conv3DBackpropFilterOp : public OpKernel {
       transformed_out_backprop = out_backprop;
     }
     Tensor transformed_input;
-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+      VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW.";
       TensorShape nchw_shape = {
           dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
           compatible_input.dim_size(2), compatible_input.dim_size(3)};
@@ -1823,7 +1921,7 @@ class Conv3DBackpropFilterOp : public OpKernel {
         dims.batch_size,
         dims.in_depth,
         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
-        FORMAT_NCHW,
+        compute_data_format,
         dims.out_depth,
         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
@@ -1947,7 +2045,7 @@ class Conv3DBackpropFilterOp : public OpKernel {
       auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
       functor::ReverseTransformFilter<GPUDevice, T, 5>()(
-          context->eigen_device<GPUDevice>(), /*src_filter_format=*/FORMAT_OIHW,
+          context->eigen_device<GPUDevice>(), /*src_filter_format=*/dst_format,
          toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
           filter_backprop->tensor<T, 5>());
     }
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 69e6fba4192..c1fe6c690cd 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -43,6 +43,7 @@ using stream_executor::dnn::DimIndex;
 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
+#include "third_party/gpus/cudnn/cudnn.h"
 #endif  // GOOGLE_CUDA

 namespace tensorflow {
@@ -201,7 +202,23 @@ struct LaunchConvOp {
       }
     }

-    if (data_format == FORMAT_NHWC) {
+#if GOOGLE_CUDA
+    const bool compute_in_nhwc = CUDNN_VERSION >= 8000 &&
+                                 DataTypeToEnum<T>::value == DT_HALF;
+#else
+    // fast NHWC implementation is a CUDA-only feature
+    const bool compute_in_nhwc = false;
+#endif
+    const TensorFormat compute_data_format =
+        (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
+                                                        : FORMAT_NCHW;
+
+    VLOG(3) << "Compute Conv3D with cuDNN:"
+            << " data_format=" << ToString(data_format)
+            << " compute_data_format=" << ToString(compute_data_format);
+
+    if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+      VLOG(4) << "Convert the input tensor from NDHWC to NCDHW.";
       const TensorShape nchw_shape = ShapeFromFormat(
           FORMAT_NCHW, in_batch, {{in_planes, in_rows, in_cols}}, in_depth);
       if (in_depth > 1) {
@@ -219,8 +236,26 @@ struct LaunchConvOp {
       } else {
         CHECK(input.CopyFrom(input, nchw_shape));
       }
+    } else {
+      CHECK(data_format == compute_data_format)  // Crash OK
+          << "Illegal data and compute format pair:"
+          << " data_format=" << ToString(data_format)
+          << " compute_data_format=" << ToString(compute_data_format);
     }

+    constexpr auto kComputeInNHWC =
+        std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
+                        se::dnn::FilterLayout::kOutputYXInput);
+    constexpr auto kComputeInNCHW =
+        std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
+                        se::dnn::FilterLayout::kOutputInputYX);
+
+    se::dnn::DataLayout compute_data_layout;
+    se::dnn::FilterLayout filter_layout;
+
+    std::tie(compute_data_layout, filter_layout) =
+        compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
+
     CHECK(pad_rows >= 0 && pad_cols >= 0 && pad_planes >= 0)
         << "Negative paddings: (" << pad_rows << ", " << pad_cols << ", "
         << pad_planes << ")";
@@ -230,20 +265,21 @@ struct LaunchConvOp {
         .set_spatial_dim(DimIndex::X, in_cols)
         .set_spatial_dim(DimIndex::Y, in_rows)
         .set_spatial_dim(DimIndex::Z, in_planes)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::BatchDescriptor output_desc(3);
     output_desc.set_count(in_batch)
         .set_spatial_dim(DimIndex::X, out_cols)
         .set_spatial_dim(DimIndex::Y, out_rows)
         .set_spatial_dim(DimIndex::Z, out_planes)
         .set_feature_map_count(out_depth)
-        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+        .set_layout(compute_data_layout);
     se::dnn::FilterDescriptor filter_desc(3);
     filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
         .set_spatial_dim(DimIndex::Y, filter_rows)
         .set_spatial_dim(DimIndex::Z, filter_planes)
         .set_input_feature_map_count(filter_depth)
-        .set_output_feature_map_count(out_depth);
+        .set_output_feature_map_count(out_depth)
+        .set_layout(filter_layout);
     se::dnn::ConvolutionDescriptor conv_desc(3);
     conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
        .set_dilation_rate(DimIndex::Y, dilations[1])
@@ -257,25 +293,42 @@ struct LaunchConvOp {
         .set_group_count(in_depth / filter_depth);

     Tensor transformed_filter;
+    auto dst_format =
+        compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
+    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
+            << " to " << ToString(dst_format);
+    TensorShape dst_shape =
+        dst_format == FORMAT_OIHW
+            ? TensorShape({filter.dim_size(4), filter.dim_size(3),
+                           filter.dim_size(0), filter.dim_size(1),
+                           filter.dim_size(2)})
+            : TensorShape({filter.dim_size(4), filter.dim_size(0),
+                           filter.dim_size(1), filter.dim_size(2),
+                           filter.dim_size(3)});
     OP_REQUIRES_OK(
         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                TensorShape({out_depth, in_depth, filter_planes,
-                                             filter_rows, filter_cols}),
+                                dst_shape,
                                 &transformed_filter));

     // filter: [x, y, z, in, out]
-    // t_filter: [out, in, x, y, z]
+    // t_filter: [out, in, x, y, z] (NCDHW) or
+    // t_filter: [out, x, y, z, in] (NDHWC)
     functor::TransformFilter<GPUDevice, T, int, 5>()(
-        ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+        ctx->eigen_device<GPUDevice>(), dst_format,
         To32Bit(filter.tensor<T, 5>()),
         To32Bit(transformed_filter.tensor<T, 5>()));

     Tensor transformed_output;
-    OP_REQUIRES_OK(
-        ctx, ctx->allocate_temp(
-                 DataTypeToEnum<T>::value,
-                 ShapeFromFormat(FORMAT_NCHW, in_batch,
-                                 {{out_planes, out_rows, out_cols}}, out_depth),
-                 &transformed_output));
+    if (data_format != compute_data_format) {
+      VLOG(4) << "Allocate temporary memory for output in compute data format";
+      OP_REQUIRES_OK(
+          ctx, ctx->allocate_temp(
+                   DataTypeToEnum<T>::value,
+                   ShapeFromFormat(FORMAT_NCHW, in_batch,
+                                   {{out_planes, out_rows, out_cols}}, out_depth),
+                   &transformed_output));
+    } else {
+      transformed_output = *output;
+    }

     auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                     input.template flat<T>().size());
@@ -295,7 +348,7 @@ struct LaunchConvOp {
         in_batch,
         in_depth,
         {{in_planes, in_rows, in_cols}},
-        FORMAT_NCHW,
+        compute_data_format,
         out_depth,
         {{filter_planes, filter_rows, filter_cols}},
         {{dilations[0], dilations[1], dilations[2]}},
@@ -455,15 +508,14 @@ struct LaunchConvOp {
                       ") filter shape(", filter.shape().DebugString(), ")"));
     }

-    if (data_format == FORMAT_NHWC) {
+    if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+      VLOG(4) << "Convert the output tensor back from NCDHW to NDHWC.";
       // t_output: [b, out, x, y, z]
       // output: [b, x, y, z, out]
       functor::NCHWToNHWC<GPUDevice, T, 5>()(
           ctx->eigen_device<GPUDevice>(),
           const_cast<const Tensor&>(transformed_output).tensor<T, 5>(),
           output->tensor<T, 5>());
-    } else {
-      *output = transformed_output;
     }
   }
 };
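
Note (not part of the patch): all three kernels touched above choose the cuDNN data and
filter layouts from the selected compute format with the same std::make_tuple/std::tie
pattern, then feed those layouts into the Batch/Filter descriptors. The following is a
minimal standalone C++ sketch of just that selection step; the enums are hypothetical
stand-ins for TensorFlow's TensorFormat and the stream_executor se::dnn layout types, so
treat it as an illustration rather than TensorFlow code.

    #include <iostream>
    #include <tuple>

    // Hypothetical stand-ins for TensorFlow's TensorFormat and the
    // se::dnn layout enums used throughout the patch.
    enum class TensorFormat { FORMAT_NHWC, FORMAT_NCHW };
    enum class DataLayout { kBatchYXDepth, kBatchDepthYX };
    enum class FilterLayout { kOutputYXInput, kOutputInputYX };

    int main() {
      // Stand-in for "cuDNN >= 8 and the element type is half" from the patch.
      const bool compute_in_nhwc = true;
      const TensorFormat data_format = TensorFormat::FORMAT_NHWC;

      // Fall back to NCHW unless the op runs in NDHWC and the fast path exists,
      // mirroring the compute_data_format selection added in each kernel.
      const TensorFormat compute_data_format =
          (compute_in_nhwc && data_format == TensorFormat::FORMAT_NHWC)
              ? TensorFormat::FORMAT_NHWC
              : TensorFormat::FORMAT_NCHW;

      // Pair each compute format with its matching data/filter layouts.
      constexpr auto kComputeInNHWC = std::make_tuple(
          DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput);
      constexpr auto kComputeInNCHW = std::make_tuple(
          DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX);

      DataLayout compute_data_layout;
      FilterLayout filter_layout;
      std::tie(compute_data_layout, filter_layout) =
          compute_data_format == TensorFormat::FORMAT_NHWC ? kComputeInNHWC
                                                           : kComputeInNCHW;

      std::cout << "compute in NHWC: "
                << (compute_data_layout == DataLayout::kBatchYXDepth) << "\n";
    }

Keeping the two layouts in a single tuple means the descriptors and the filter
transform destination format cannot drift apart when a new compute format is added.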