From ee82131dbccd4e99decb8c05c43bc2bb387ad6ac Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne <reedwm@google.com> Date: Fri, 19 Apr 2019 14:18:19 -0700 Subject: [PATCH] Support explicit padding on CPU for tf.nn.conv2d. PiperOrigin-RevId: 244421216 --- tensorflow/core/kernels/conv_2d.h | 43 +- .../core/kernels/conv_grad_filter_ops.cc | 27 +- .../core/kernels/conv_grad_input_ops.cc | 27 +- tensorflow/core/kernels/conv_ops.cc | 38 +- .../kernels/eigen_spatial_convolutions-inl.h | 86 +- .../python/kernel_tests/conv_ops_test.py | 824 +++++++++--------- tensorflow/python/ops/nn_ops.py | 2 +- 7 files changed, 547 insertions(+), 500 deletions(-) diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h index 1bac2a18c30..b735f78c2e3 100644 --- a/tensorflow/core/kernels/conv_2d.h +++ b/tensorflow/core/kernels/conv_2d.h @@ -57,11 +57,16 @@ void SpatialConvolutionFunc(const Device& d, Output output, Input input, Filter filter, int row_stride, int col_stride, int row_dilation, int col_dilation, const Eigen::PaddingType& padding, - const OutputKernel& output_kernel) { - // Need to swap row/col when calling Eigen. - output.device(d) = - Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding, - col_dilation, row_dilation, output_kernel); + const OutputKernel& output_kernel, + int padding_top = 0, int padding_bottom = 0, + int padding_left = 0, int padding_right = 0) { + // Need to swap row/col, padding_top/padding_left, and + // padding_bottom/padding_right when calling Eigen. Eigen expects the tensor + // in NWHC format, but the tensor given is in NHWC. + output.device(d) = Eigen::SpatialConvolution( + input, filter, col_stride, row_stride, padding, col_dilation, + row_dilation, output_kernel, padding_left, padding_right, padding_top, + padding_bottom); } template <typename Device, typename T, @@ -76,6 +81,18 @@ struct SpatialConvolution { SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride, row_dilation, col_dilation, padding, output_kernel); } + void operator()(const Device& d, typename TTypes<T, 4>::Tensor output, + typename TTypes<T, 4>::ConstTensor input, + typename TTypes<T, 4>::ConstTensor filter, int row_stride, + int col_stride, int row_dilation, int col_dilation, + int padding_top, int padding_bottom, int padding_left, + int padding_right, + const OutputKernel& output_kernel = OutputKernel()) { + SpatialConvolutionFunc( + d, output, input, filter, row_stride, col_stride, row_dilation, + col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel, + padding_top, padding_bottom, padding_left, padding_right); + } }; template <typename Device, typename OutputKernel> @@ -93,6 +110,22 @@ struct SpatialConvolution<Device, Eigen::half, OutputKernel> { row_dilation, output_kernel) .template cast<Eigen::half>(); } + void operator()(const Device& d, + typename TTypes<Eigen::half, 4>::Tensor output, + typename TTypes<Eigen::half, 4>::ConstTensor input, + typename TTypes<Eigen::half, 4>::ConstTensor filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, int padding_top, int padding_bottom, + int padding_left, int padding_right, + const OutputKernel& output_kernel = OutputKernel()) { + output.device(d) = + Eigen::SpatialConvolution( + input.cast<float>(), filter.cast<float>(), col_stride, row_stride, + Eigen::PaddingType::PADDING_VALID, col_dilation, row_dilation, + output_kernel, padding_left, padding_right, padding_top, + padding_bottom) + .template cast<Eigen::half>(); + } }; template <typename Device, typename T> diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 168a91a312a..e755c3e2041 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -208,14 +208,9 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { errors::InvalidArgument( "Row and column strides should be larger than 0.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - OP_REQUIRES( - context, padding_ != Padding::EXPLICIT, - errors::Unimplemented("Current CPU implementation does not support " - "EXPLICIT padding yet.")); - std::vector<int64> explicit_paddings; OP_REQUIRES_OK(context, - context->GetAttr("explicit_paddings", &explicit_paddings)); - OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings, + context->GetAttr("explicit_paddings", &explicit_paddings_)); + OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_, /*num_dims=*/4, data_format_)); OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); OP_REQUIRES(context, dilations_.size() == 4, @@ -247,11 +242,12 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { filter_sizes.vec<int32>(), &filter_shape)); ConvBackpropDimensions dims; - OP_REQUIRES_OK(context, - ConvBackpropComputeDimensions( - "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, - input.shape(), filter_shape, out_backprop.shape(), - strides_, padding_, data_format_, &dims)); + OP_REQUIRES_OK( + context, + ConvBackpropComputeDimensionsV2( + "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, input.shape(), + filter_shape, out_backprop.shape(), /*dilations=*/{1, 1, 1, 1}, + strides_, padding_, explicit_paddings_, data_format_, &dims)); Tensor* filter_backprop; OP_REQUIRES_OK(context, @@ -264,6 +260,12 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { int64 pad_top, pad_bottom; int64 pad_left, pad_right; + if (padding_ == Padding::EXPLICIT) { + pad_top = explicit_paddings_[2]; + pad_bottom = explicit_paddings_[3]; + pad_left = explicit_paddings_[4]; + pad_right = explicit_paddings_[5]; + } OP_REQUIRES_OK( context, GetWindowedOutputSizeVerbose( @@ -402,6 +404,7 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; + std::vector<int64> explicit_paddings_; TensorFormat data_format_; TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropFilterOp); diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index 471c73f65a4..4c1a0d9316b 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -299,14 +299,9 @@ class Conv2DCustomBackpropInputOp : public OpKernel { errors::InvalidArgument( "Current libxsmm and customized CPU implementations do " "not yet support dilation rates larger than 1.")); - OP_REQUIRES( - context, padding_ != Padding::EXPLICIT, - errors::Unimplemented("Current CPU implementation does not support " - "EXPLICIT padding yet.")); - std::vector<int64> explicit_paddings; OP_REQUIRES_OK(context, - context->GetAttr("explicit_paddings", &explicit_paddings)); - OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings, + context->GetAttr("explicit_paddings", &explicit_paddings_)); + OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_, /*num_dims=*/4, data_format_)); } @@ -325,10 +320,11 @@ class Conv2DCustomBackpropInputOp : public OpKernel { ConvBackpropDimensions dims; OP_REQUIRES_OK(context, - ConvBackpropComputeDimensions( + ConvBackpropComputeDimensionsV2( "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2, input_shape, filter.shape(), out_backprop.shape(), - strides_, padding_, data_format_, &dims)); + /*dilations=*/{1, 1, 1, 1}, strides_, padding_, + explicit_paddings_, data_format_, &dims)); Tensor* in_backprop = nullptr; OP_REQUIRES_OK(context, @@ -375,6 +371,12 @@ class Conv2DCustomBackpropInputOp : public OpKernel { int64 pad_top, pad_bottom; int64 pad_left, pad_right; #endif + if (padding_ == Padding::EXPLICIT) { + pad_top = explicit_paddings_[2]; + pad_bottom = explicit_paddings_[3]; + pad_left = explicit_paddings_[4]; + pad_right = explicit_paddings_[5]; + } OP_REQUIRES_OK( context, GetWindowedOutputSizeVerbose( @@ -536,6 +538,7 @@ class Conv2DCustomBackpropInputOp : public OpKernel { std::vector<int32> dilations_; std::vector<int32> strides_; Padding padding_; + std::vector<int64> explicit_paddings_; TensorFormat data_format_; TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp); @@ -617,12 +620,6 @@ class Conv2DSlowBackpropInputOp : public OpKernel { use_cudnn_ &= CanUseCudnn(); cudnn_use_autotune_ = CudnnUseAutotune(); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - if (!std::is_same<Device, GPUDevice>::value) { - OP_REQUIRES( - context, padding_ != Padding::EXPLICIT, - errors::Unimplemented("Current CPU implementation does not support " - "EXPLICIT padding yet.")); - } OP_REQUIRES_OK(context, context->GetAttr("explicit_paddings", &explicit_paddings_)); OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_, diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index f5ec3d91f6a..ec54ece9d7c 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -70,11 +70,12 @@ struct LaunchGeneric { void operator()(OpKernelContext* ctx, const Tensor& input, const Tensor& filter, int row_stride, int col_stride, int row_dilation, int col_dilation, const Padding& padding, - Tensor* output, TensorFormat data_format) { + const std::vector<int64>& explicit_paddings, Tensor* output, + TensorFormat data_format) { CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only " "supports NHWC tensor format for now."; if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 && - col_stride == 1) { + col_stride == 1 && (padding == SAME || padding == VALID)) { // For 1x1 kernel, the 2D convolution is reduced to matrix // multiplication. // @@ -110,10 +111,20 @@ struct LaunchGeneric { input.shaped<T, 2>({input.dim_size(0), k}), filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair); } else { - functor::SpatialConvolution<Device, T>()( - ctx->eigen_device<Device>(), output->tensor<T, 4>(), - input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride, - row_dilation, col_dilation, BrainPadding2EigenPadding(padding)); + if (padding == EXPLICIT) { + functor::SpatialConvolution<Device, T>()( + ctx->eigen_device<Device>(), output->tensor<T, 4>(), + input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride, + row_dilation, col_dilation, static_cast<int>(explicit_paddings[2]), + static_cast<int>(explicit_paddings[3]), + static_cast<int>(explicit_paddings[4]), + static_cast<int>(explicit_paddings[5])); + } else { + functor::SpatialConvolution<Device, T>()( + ctx->eigen_device<Device>(), output->tensor<T, 4>(), + input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride, + row_dilation, col_dilation, BrainPadding2EigenPadding(padding)); + } } } }; @@ -133,18 +144,19 @@ struct LaunchConv2DOp<CPUDevice, T> { "NHWC tensor format for now.")); return; } - // TODO(reedwm): Enable explicit padding on the CPU. - OP_REQUIRES( - ctx, padding != Padding::EXPLICIT, - errors::Unimplemented("Generic conv implementation does not support " - "EXPLICIT padding yet.")); const int64 in_depth = GetTensorDim(input, data_format, 'C'); OP_REQUIRES(ctx, in_depth == filter.dim_size(2), errors::Unimplemented("Generic conv implementation does not " "support grouped convolutions for now.")); + for (int64 explicit_padding : explicit_paddings) { + if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) { + ctx->SetStatus(errors::InvalidArgument("filter too large")); + return; + } + } LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride, - row_dilation, col_dilation, padding, output, - data_format); + row_dilation, col_dilation, padding, + explicit_paddings, output, data_format); } }; diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h b/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h index a2afab42ec1..4559ac3837c 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h +++ b/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h @@ -1313,6 +1313,10 @@ struct gemm_pack_rhs< * (aka atrous convolution), sampling every col_in_stride, row_in_stride input * pixels. * + * If padding_top, padding_bottom, padding_left, or padding_right is specified, + * then those paddings will be used to pad the input, and padding_type must be + * PADDING_VALID. + * * The result can be assigned to a tensor of rank equal to the rank of the * input. The dimensions of the result will be filters, height, width (and * others if applicable). @@ -1360,7 +1364,9 @@ EIGEN_DEVICE_FUNC const PaddingType padding_type = PADDING_SAME, const Index row_in_stride = 1, const Index col_in_stride = 1, - const OutputKernel& output_kernel = OutputKernel()) { + const OutputKernel& output_kernel = OutputKernel(), + Index padding_top = 0, Index padding_bottom = 0, + Index padding_left = 0, Index padding_right = 0) { typedef typename internal::traits<Input>::Index TensorIndex; TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, @@ -1402,25 +1408,33 @@ EIGEN_DEVICE_FUNC isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); const TensorIndex InputCols = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); + const bool padding_explicit = + (padding_top || padding_bottom || padding_left || padding_right); TensorIndex out_height; TensorIndex out_width; switch (padding_type) { - case PADDING_VALID: - out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) / + case PADDING_VALID: { + const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom; + const TensorIndex InputColsEff = InputCols + padding_left + padding_right; + out_height = numext::ceil((InputRowsEff - kernelRowsEff + 1.f) / static_cast<float>(row_stride)); - out_width = numext::ceil((InputCols - kernelColsEff + 1.f) / + out_width = numext::ceil((InputColsEff - kernelColsEff + 1.f) / static_cast<float>(col_stride)); break; - case PADDING_SAME: + } + case PADDING_SAME: { + eigen_assert(!padding_explicit); out_height = numext::ceil(InputRows / static_cast<float>(row_stride)); out_width = numext::ceil(InputCols / static_cast<float>(col_stride)); break; - default: + } + default: { // Initialize unused variables to avoid a compiler warning out_height = 0; out_width = 0; eigen_assert(false && "unexpected padding"); + } } // Molds the output of the patch extraction code into a 2d tensor: @@ -1473,22 +1487,50 @@ EIGEN_DEVICE_FUNC kernel_dims[0] = kernelChannels * kernelRows * kernelCols; kernel_dims[1] = kernelFilters; } - return choose( - Cond<internal::traits<Input>::Layout == ColMajor>(), - kernel.reshape(kernel_dims) - .contract(input - .extract_image_patches( - kernelRows, kernelCols, row_stride, col_stride, - row_in_stride, col_in_stride, padding_type) - .reshape(pre_contract_dims), - contract_dims, output_kernel) - .reshape(post_contract_dims), - input - .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, - row_in_stride, col_in_stride, padding_type) - .reshape(pre_contract_dims) - .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel) - .reshape(post_contract_dims)); + if (padding_explicit) { + return choose( + Cond<internal::traits<Input>::Layout == ColMajor>(), + kernel.reshape(kernel_dims) + .contract(input + .extract_image_patches( + kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, + /*row_inflate_stride=*/1, + /*col_inflate_stride=*/1, padding_top, + padding_bottom, padding_left, padding_right, + /*padding_value=*/0) + .reshape(pre_contract_dims), + contract_dims, output_kernel) + .reshape(post_contract_dims), + input + .extract_image_patches(kernelRows, kernelCols, row_stride, + col_stride, row_in_stride, col_in_stride, + /*row_inflate_stride=*/1, + /*col_inflate_stride=*/1, padding_top, + padding_bottom, padding_left, padding_right, + /*padding_value=*/0) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel) + .reshape(post_contract_dims)); + } else { + return choose( + Cond<internal::traits<Input>::Layout == ColMajor>(), + kernel.reshape(kernel_dims) + .contract(input + .extract_image_patches( + kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, padding_type) + .reshape(pre_contract_dims), + contract_dims, output_kernel) + .reshape(post_contract_dims), + input + .extract_image_patches(kernelRows, kernelCols, row_stride, + col_stride, row_in_stride, col_in_stride, + padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel) + .reshape(post_contract_dims)); + } } } // end namespace Eigen diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index b7290497702..0bec67f5213 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -403,7 +403,6 @@ class Conv2DTest(test.TestCase): padding, expected, dilations, - gpu_only=True, test_grappler_layout_optimizer=test_grappler_layout_optimizer, tol=tol, fp16_tol=fp16_tol) @@ -1429,8 +1428,14 @@ class Conv2DTest(test.TestCase): strides, padding, data_format, + use_gpu, dilations=(1, 1), err=2e-5): + if use_gpu and not test.is_gpu_available(cuda_only=True): + return + if not use_gpu and dilations != (1, 1): + return # Non-default dilations is currently not supported on the CPU. + x1 = self._CreateNumpyTensor(filter_sizes) x2 = self._CreateNumpyTensor(output_sizes) dilations = list(dilations) @@ -1455,133 +1460,128 @@ class Conv2DTest(test.TestCase): padding, expected, data_format, - use_gpu=True, + use_gpu=use_gpu, err=err, dilations=dilations) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding0x0BackpropInput(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 2, 3, 1], - filter_sizes=[2, 2, 1, 1], - output_sizes=[1, 1, 2, 1], - strides=[1, 1], - padding=[[0, 0], [0, 0]], - data_format=data_format) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 1, 2, 1], + strides=[1, 1], + padding=[[0, 0], [0, 0]], + data_format=data_format, + use_gpu=use_gpu) - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 3, 4, 2], - filter_sizes=[2, 2, 2, 3], - output_sizes=[1, 1, 2, 3], - strides=[2, 2], - padding=[[0, 0], [0, 0]], - data_format=data_format) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 3, 4, 2], + filter_sizes=[2, 2, 2, 3], + output_sizes=[1, 1, 2, 3], + strides=[2, 2], + padding=[[0, 0], [0, 0]], + data_format=data_format, + use_gpu=use_gpu) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding1x1BackpropInput(self): - if not test.is_gpu_available(cuda_only=True): - return - for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 2, 3, 1], - filter_sizes=[2, 2, 1, 2], - output_sizes=[1, 3, 4, 2], - strides=[1, 1], - padding=[[1, 1], [1, 1]], - data_format=data_format, err=1e-4) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 2], + output_sizes=[1, 3, 4, 2], + strides=[1, 1], + padding=[[1, 1], [1, 1]], + data_format=data_format, + use_gpu=use_gpu, + err=1e-4) - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 2, 3, 2], - filter_sizes=[1, 1, 2, 1], - output_sizes=[1, 4, 3, 1], - strides=[1, 2], - padding=[[1, 1], [1, 1]], - data_format=data_format) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 2, 3, 2], + filter_sizes=[1, 1, 2, 1], + output_sizes=[1, 4, 3, 1], + strides=[1, 2], + padding=[[1, 1], [1, 1]], + data_format=data_format, + use_gpu=use_gpu) - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 4, 3, 1], - filter_sizes=[2, 2, 1, 1], - output_sizes=[1, 4, 2, 1], - strides=[1, 2], - padding=[[1, 1], [1, 1]], - data_format=data_format, - dilations=[2, 2]) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 4, 3, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 4, 2, 1], + strides=[1, 2], + padding=[[1, 1], [1, 1]], + data_format=data_format, + dilations=[2, 2], use_gpu=use_gpu) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding2x2BackpropInput(self): - if not test.is_gpu_available(cuda_only=True): - return - for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[2, 3, 1, 1], - filter_sizes=[2, 1, 1, 1], - output_sizes=[2, 2, 5, 1], - strides=[3, 1], - padding=[[2, 2], [2, 2]], - data_format=data_format) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[2, 3, 1, 1], + filter_sizes=[2, 1, 1, 1], + output_sizes=[2, 2, 5, 1], + strides=[3, 1], + padding=[[2, 2], [2, 2]], + data_format=data_format, + use_gpu=use_gpu) - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 3, 6, 1], - filter_sizes=[3, 2, 1, 1], - output_sizes=[1, 3, 4, 1], - strides=[1, 2], - padding=[[2, 2], [2, 2]], - data_format=data_format, - dilations=[2, 3]) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 3, 6, 1], + filter_sizes=[3, 2, 1, 1], + output_sizes=[1, 3, 4, 1], + strides=[1, 2], + padding=[[2, 2], [2, 2]], + data_format=data_format, + dilations=[2, 3], + use_gpu=use_gpu) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding_1_8_4_1_BackpropInput(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 2, 3, 1], - filter_sizes=[2, 2, 1, 1], - output_sizes=[1, 10, 8, 1], - strides=[1, 1], - padding=[[1, 8], [4, 2]], - data_format=data_format, err=5e-5) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 10, 8, 1], + strides=[1, 1], + padding=[[1, 8], [4, 2]], + data_format=data_format, + use_gpu=use_gpu, + err=5e-5) - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 5, 3, 1], - filter_sizes=[3, 2, 1, 1], - output_sizes=[1, 4, 8, 1], - strides=[3, 1], - padding=[[1, 8], [4, 2]], - data_format=data_format) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 5, 3, 1], + filter_sizes=[3, 2, 1, 1], + output_sizes=[1, 4, 8, 1], + strides=[3, 1], + padding=[[1, 8], [4, 2]], + data_format=data_format, + use_gpu=use_gpu) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding_5_0_2_2_BackpropInput(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 3, 3, 1], - filter_sizes=[2, 1, 1, 1], - output_sizes=[1, 7, 7, 1], - strides=[1, 1], - padding=[[5, 0], [2, 2]], - data_format=data_format, - err=5e-5) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 1, 1, 1], + output_sizes=[1, 7, 7, 1], + strides=[1, 1], + padding=[[5, 0], [2, 2]], + data_format=data_format, + err=5e-5, + use_gpu=use_gpu) - self._RunAndVerifyBackpropInputExplicitPadding( - input_sizes=[1, 4, 2, 1], - filter_sizes=[3, 3, 1, 1], - output_sizes=[1, 5, 2, 1], - strides=[1, 2], - padding=[[5, 0], [2, 2]], - data_format=data_format, - dilations=[2, 1]) + self._RunAndVerifyBackpropInputExplicitPadding( + input_sizes=[1, 4, 2, 1], + filter_sizes=[3, 3, 1, 1], + output_sizes=[1, 5, 2, 1], + strides=[1, 2], + padding=[[5, 0], [2, 2]], + data_format=data_format, + dilations=[2, 1], + use_gpu=use_gpu) def _RunAndVerifyBackpropFilterExplicitPadding(self, input_sizes, @@ -1590,8 +1590,14 @@ class Conv2DTest(test.TestCase): strides, padding, data_format, + use_gpu, dilations=(1, 1), err=1e-5): + if use_gpu and not test.is_gpu_available(cuda_only=True): + return + if not use_gpu and dilations != (1, 1): + return # Non-default dilations is currently not supported on the CPU. + x0 = self._CreateNumpyTensor(input_sizes) x2 = self._CreateNumpyTensor(output_sizes) dilations = list(dilations) @@ -1613,135 +1619,127 @@ class Conv2DTest(test.TestCase): padding, expected, data_format, - use_gpu=True, + use_gpu=use_gpu, dilations=dilations, err=err) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding0x0BackpropFilter(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 2, 3, 1], - filter_sizes=[2, 2, 1, 1], - output_sizes=[1, 1, 2, 1], - strides=[1, 1], - padding=[[0, 0], [0, 0]], - data_format=data_format) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 1, 2, 1], + strides=[1, 1], + padding=[[0, 0], [0, 0]], + data_format=data_format, use_gpu=use_gpu) - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 3, 4, 2], - filter_sizes=[2, 2, 2, 3], - output_sizes=[1, 1, 2, 3], - strides=[2, 2], - padding=[[0, 0], [0, 0]], - data_format=data_format) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 3, 4, 2], + filter_sizes=[2, 2, 2, 3], + output_sizes=[1, 1, 2, 3], + strides=[2, 2], + padding=[[0, 0], [0, 0]], + data_format=data_format, use_gpu=use_gpu) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding1x1BackpropFilter(self): - if not test.is_gpu_available(cuda_only=True): - return - for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 2, 3, 1], - filter_sizes=[2, 2, 1, 2], - output_sizes=[1, 3, 4, 2], - strides=[1, 1], - padding=[[1, 1], [1, 1]], - data_format=data_format, - err=5e-5) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 2], + output_sizes=[1, 3, 4, 2], + strides=[1, 1], + padding=[[1, 1], [1, 1]], + data_format=data_format, + use_gpu=use_gpu, + err=5e-5) - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 2, 3, 2], - filter_sizes=[1, 1, 2, 1], - output_sizes=[1, 4, 3, 1], - strides=[1, 2], - padding=[[1, 1], [1, 1]], - data_format=data_format) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 2, 3, 2], + filter_sizes=[1, 1, 2, 1], + output_sizes=[1, 4, 3, 1], + strides=[1, 2], + padding=[[1, 1], [1, 1]], + use_gpu=use_gpu, + data_format=data_format) - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 4, 3, 1], - filter_sizes=[2, 2, 1, 1], - output_sizes=[1, 4, 2, 1], - strides=[1, 2], - padding=[[1, 1], [1, 1]], - data_format=data_format, - dilations=[2, 2]) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 4, 3, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 4, 2, 1], + strides=[1, 2], + padding=[[1, 1], [1, 1]], + data_format=data_format, + use_gpu=use_gpu, + dilations=[2, 2]) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding2x2BackpropFilter(self): - if not test.is_gpu_available(cuda_only=True): - return - for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[2, 3, 1, 1], - filter_sizes=[2, 1, 1, 1], - output_sizes=[2, 2, 5, 1], - strides=[3, 1], - padding=[[2, 2], [2, 2]], - data_format=data_format) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[2, 3, 1, 1], + filter_sizes=[2, 1, 1, 1], + output_sizes=[2, 2, 5, 1], + strides=[3, 1], + padding=[[2, 2], [2, 2]], + data_format=data_format, + use_gpu=use_gpu) - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 3, 6, 1], - filter_sizes=[3, 2, 1, 1], - output_sizes=[1, 3, 4, 1], - strides=[1, 2], - padding=[[2, 2], [2, 2]], - data_format=data_format, - dilations=[2, 3]) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 3, 6, 1], + filter_sizes=[3, 2, 1, 1], + output_sizes=[1, 3, 4, 1], + strides=[1, 2], + padding=[[2, 2], [2, 2]], + data_format=data_format, + use_gpu=use_gpu, + dilations=[2, 3]) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding_1_8_4_1_BackpropFilter(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 2, 3, 1], - filter_sizes=[2, 2, 1, 1], - output_sizes=[1, 10, 8, 1], - strides=[1, 1], - padding=[[1, 8], [4, 2]], - data_format=data_format, - err=1e-4) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + output_sizes=[1, 10, 8, 1], + strides=[1, 1], + padding=[[1, 8], [4, 2]], + data_format=data_format, + use_gpu=use_gpu, + err=1e-4) - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 5, 3, 1], - filter_sizes=[3, 2, 1, 1], - output_sizes=[1, 4, 8, 1], - strides=[3, 1], - padding=[[1, 8], [4, 2]], - data_format=data_format) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 5, 3, 1], + filter_sizes=[3, 2, 1, 1], + output_sizes=[1, 4, 8, 1], + strides=[3, 1], + padding=[[1, 8], [4, 2]], + use_gpu=use_gpu, + data_format=data_format) @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1Padding_5_0_2_2_BackpropFilter(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 3, 3, 1], - filter_sizes=[2, 1, 1, 1], - output_sizes=[1, 7, 7, 1], - strides=[1, 1], - padding=[[5, 0], [2, 2]], - data_format=data_format, - err=1e-4) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 1, 1, 1], + output_sizes=[1, 7, 7, 1], + strides=[1, 1], + padding=[[5, 0], [2, 2]], + data_format=data_format, + use_gpu=use_gpu, + err=1e-4) - self._RunAndVerifyBackpropFilterExplicitPadding( - input_sizes=[1, 4, 2, 1], - filter_sizes=[3, 3, 1, 1], - output_sizes=[1, 5, 2, 1], - strides=[1, 2], - padding=[[5, 0], [2, 2]], - data_format=data_format, - dilations=[2, 1]) + self._RunAndVerifyBackpropFilterExplicitPadding( + input_sizes=[1, 4, 2, 1], + filter_sizes=[3, 3, 1, 1], + output_sizes=[1, 5, 2, 1], + strides=[1, 2], + padding=[[5, 0], [2, 2]], + data_format=data_format, + use_gpu=use_gpu, + dilations=[2, 1]) # Gradient checkers def ConstructAndTestGradient(self, @@ -2107,257 +2105,221 @@ class Conv2DTest(test.TestCase): @test_util.deprecated_graph_mode_only def testInputGradient1x1PaddingStrideOne(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=5, - input_cols=4, - filter_rows=3, - filter_cols=3, - in_depth=2, - out_depth=3, - stride_rows=1, - stride_cols=1, - padding=[[0, 0], [1, 1], [1, 1], [0, 0]], - test_input=True, - data_format=data_format, - use_gpu=use_gpu, - max_err=0.0025) + self.ConstructAndTestGradient( + batch=2, + input_rows=5, + input_cols=4, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride_rows=1, + stride_cols=1, + padding=[[0, 0], [1, 1], [1, 1], [0, 0]], + test_input=True, + data_format=data_format, + use_gpu=use_gpu, + max_err=0.0025) @test_util.deprecated_graph_mode_only def testFilterGradient1x1PaddingStrideOne(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=5, - input_cols=4, - filter_rows=3, - filter_cols=3, - in_depth=2, - out_depth=3, - stride_rows=1, - stride_cols=1, - padding=[[0, 0], [1, 1], [1, 1], [0, 0]], - test_input=False, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=2, + input_rows=5, + input_cols=4, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride_rows=1, + stride_cols=1, + padding=[[0, 0], [1, 1], [1, 1], [0, 0]], + test_input=False, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testInputGradient1x1PaddingStrideTwo(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=4, - input_cols=5, - filter_rows=3, - filter_cols=3, - in_depth=2, - out_depth=3, - stride_rows=2, - stride_cols=2, - padding=[[0, 0], [1, 1], [1, 1], [0, 0]], - test_input=True, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=2, + input_rows=4, + input_cols=5, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride_rows=2, + stride_cols=2, + padding=[[0, 0], [1, 1], [1, 1], [0, 0]], + test_input=True, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testFilterGradient1x1PaddingStrideTwo(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=4, - input_cols=5, - filter_rows=3, - filter_cols=3, - in_depth=2, - out_depth=3, - stride_rows=2, - stride_cols=2, - padding=[[0, 0], [1, 1], [1, 1], [0, 0]], - test_input=False, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=2, + input_rows=4, + input_cols=5, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride_rows=2, + stride_cols=2, + padding=[[0, 0], [1, 1], [1, 1], [0, 0]], + test_input=False, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testInputGradient2x2PaddingStrideOne(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=5, - input_cols=4, - filter_rows=3, - filter_cols=3, - in_depth=2, - out_depth=3, - stride_rows=1, - stride_cols=1, - padding=[[0, 0], [2, 2], [2, 2], [0, 0]], - test_input=True, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=2, + input_rows=5, + input_cols=4, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride_rows=1, + stride_cols=1, + padding=[[0, 0], [2, 2], [2, 2], [0, 0]], + test_input=True, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testFilterGradient2x2PaddingStrideOne(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=5, - input_cols=4, - filter_rows=3, - filter_cols=3, - in_depth=2, - out_depth=3, - stride_rows=1, - stride_cols=1, - padding=[[0, 0], [2, 2], [2, 2], [0, 0]], - test_input=False, - data_format=data_format, - use_gpu=use_gpu, - max_err=0.003) + self.ConstructAndTestGradient( + batch=2, + input_rows=5, + input_cols=4, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride_rows=1, + stride_cols=1, + padding=[[0, 0], [2, 2], [2, 2], [0, 0]], + test_input=False, + data_format=data_format, + use_gpu=use_gpu, + max_err=0.003) @test_util.deprecated_graph_mode_only def testInputGradient1_2_3_4PaddingStride3x2(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=8, - input_cols=5, - filter_rows=4, - filter_cols=2, - in_depth=3, - out_depth=2, - stride_rows=3, - stride_cols=2, - padding=[[0, 0], [1, 2], [3, 4], [0, 0]], - test_input=True, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=2, + input_rows=8, + input_cols=5, + filter_rows=4, + filter_cols=2, + in_depth=3, + out_depth=2, + stride_rows=3, + stride_cols=2, + padding=[[0, 0], [1, 2], [3, 4], [0, 0]], + test_input=True, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testFilterGradient1_2_3_4PaddingStride3x2(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=8, - input_cols=5, - filter_rows=4, - filter_cols=2, - in_depth=3, - out_depth=2, - stride_rows=3, - stride_cols=2, - padding=[[0, 0], [1, 2], [3, 4], [0, 0]], - test_input=False, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=2, + input_rows=8, + input_cols=5, + filter_rows=4, + filter_cols=2, + in_depth=3, + out_depth=2, + stride_rows=3, + stride_cols=2, + padding=[[0, 0], [1, 2], [3, 4], [0, 0]], + test_input=False, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testInputGradient4_3_2_1PaddingStride2x1(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=3, - input_rows=5, - input_cols=7, - filter_rows=3, - filter_cols=2, - in_depth=1, - out_depth=2, - stride_rows=2, - stride_cols=1, - padding=[[0, 0], [4, 3], [2, 1], [0, 0]], - test_input=True, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=3, + input_rows=5, + input_cols=7, + filter_rows=3, + filter_cols=2, + in_depth=1, + out_depth=2, + stride_rows=2, + stride_cols=1, + padding=[[0, 0], [4, 3], [2, 1], [0, 0]], + test_input=True, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testFilterGradient4_3_2_1PaddingStride2x1(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=3, - input_rows=5, - input_cols=7, - filter_rows=3, - filter_cols=2, - in_depth=1, - out_depth=2, - stride_rows=2, - stride_cols=1, - padding=[[0, 0], [4, 3], [2, 1], [0, 0]], - test_input=False, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=3, + input_rows=5, + input_cols=7, + filter_rows=3, + filter_cols=2, + in_depth=1, + out_depth=2, + stride_rows=2, + stride_cols=1, + padding=[[0, 0], [4, 3], [2, 1], [0, 0]], + test_input=False, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testInputGradient0_0_0_5PaddingStride1x2(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=6, - input_cols=7, - filter_rows=3, - filter_cols=4, - in_depth=3, - out_depth=2, - stride_rows=1, - stride_cols=2, - padding=[[0, 0], [0, 0], [0, 5], [0, 0]], - test_input=True, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=2, + input_rows=6, + input_cols=7, + filter_rows=3, + filter_cols=4, + in_depth=3, + out_depth=2, + stride_rows=1, + stride_cols=2, + padding=[[0, 0], [0, 0], [0, 5], [0, 0]], + test_input=True, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testFilterGradient0_0_0_5PaddingStride1x2(self): - if not test.is_gpu_available(cuda_only=True): - return for (data_format, use_gpu) in GetTestConfigs(): - if use_gpu: - self.ConstructAndTestGradient( - batch=2, - input_rows=6, - input_cols=7, - filter_rows=3, - filter_cols=4, - in_depth=3, - out_depth=2, - stride_rows=1, - stride_cols=2, - padding=[[0, 0], [0, 0], [0, 5], [0, 0]], - test_input=False, - data_format=data_format, - use_gpu=use_gpu) + self.ConstructAndTestGradient( + batch=2, + input_rows=6, + input_cols=7, + filter_rows=3, + filter_cols=4, + in_depth=3, + out_depth=2, + stride_rows=1, + stride_cols=2, + padding=[[0, 0], [0, 0], [0, 5], [0, 0]], + test_input=False, + data_format=data_format, + use_gpu=use_gpu) @test_util.deprecated_graph_mode_only def testShapeFunctionEdgeCases(self): @@ -2505,31 +2467,29 @@ class Conv2DTest(test.TestCase): strides=[1, 1, 1, 1], padding=[[0, 0], [2, 2], [2, 2], [0, 0]])) - if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True): - # Negative padding during backprop. - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "nonnegative"): - sess.run( - nn_ops.conv2d_backprop_input([32, 20, 20, 3], - array_ops.placeholder( - dtypes.float32, - shape=[18, 18, 3, 2]), - array_ops.placeholder( - dtypes.float32, - shape=[32, 3, 2, 2]), - strides=[1, 1, 1, 1], - padding=[[0, 0], [-1, 0], [0, 0], - [0, 0]])) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "nonnegative"): - sess.run( - nn_ops.conv2d_backprop_filter( - array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]), - [18, 18, 3, 2], - array_ops.placeholder(dtypes.float32, shape=[32, 3, 2, 2]), - strides=[1, 1, 1, 1], - padding=[[0, 0], [-1, 0], [0, 0], [0, 0]])) + # Negative padding during backprop. + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "nonnegative"): + sess.run( + nn_ops.conv2d_backprop_input([32, 20, 20, 3], + array_ops.placeholder( + dtypes.float32, + shape=[18, 18, 3, 2]), + array_ops.placeholder( + dtypes.float32, + shape=[32, 3, 2, 2]), + strides=[1, 1, 1, 1], + padding=[[0, 0], [-1, 0], [0, 0], + [0, 0]])) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "nonnegative"): + sess.run( + nn_ops.conv2d_backprop_filter( + array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]), + [18, 18, 3, 2], + array_ops.placeholder(dtypes.float32, shape=[32, 3, 2, 2]), + strides=[1, 1, 1, 1], + padding=[[0, 0], [-1, 0], [0, 0], [0, 0]])) class DepthwiseConv2DTest(test.TestCase): diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 97cbe55403e..50583c2a893 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1904,7 +1904,7 @@ def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value value is given it is replicated in the `H` and `W` dimension. By default the `N` and `C` dimensions are set to 1. The dimension order is determined by the value of `data_format`, see below for details. - padding: Either the `string `"SAME"` or `"VALID"` indicating the type of + padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of padding algorithm to use, or a list indicating the explicit paddings at the start and end of each dimension. When explicit padding is used and data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,