diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index 2142207b0d8..6949e5b5fd8 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -54,10 +54,12 @@ struct InflatePadAndShuffle {
 template <typename Device, typename Output, typename Input, typename Filter>
 void SpatialConvolutionFunc(const Device& d, Output output, Input input,
                             Filter filter, int row_stride, int col_stride,
+                            int row_dilation, int col_dilation,
                             const Eigen::PaddingType& padding) {
   // Need to swap row/col when calling Eigen.
   output.device(d) =
-      Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding);
+      Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
+                                col_dilation, row_dilation);
 }
 
 template <typename Device, typename T>
@@ -65,9 +67,10 @@ struct SpatialConvolution {
   void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
                   typename TTypes<T, 4>::ConstTensor input,
                   typename TTypes<T, 4>::ConstTensor filter, int row_stride,
-                  int col_stride, const Eigen::PaddingType& padding) {
+                  int col_stride, int row_dilation, int col_dilation,
+                  const Eigen::PaddingType& padding) {
     SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
-                           padding);
+                           row_dilation, col_dilation, padding);
   }
 };
 
@@ -77,11 +80,12 @@ struct SpatialConvolution<Device, Eigen::half> {
                   typename TTypes<Eigen::half, 4>::Tensor output,
                   typename TTypes<Eigen::half, 4>::ConstTensor input,
                   typename TTypes<Eigen::half, 4>::ConstTensor filter,
-                  int row_stride, int col_stride,
-                  const Eigen::PaddingType& padding) {
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, const Eigen::PaddingType& padding) {
     output.device(d) =
         Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(),
-                                  col_stride, row_stride, padding)
+                                  col_stride, row_stride, padding, col_dilation,
+                                  row_dilation)
             .cast<Eigen::half>();
   }
 };
 
@@ -91,11 +95,13 @@ struct SpatialConvolutionBackwardInput {
   void operator()(const Device& d,
                   typename TTypes<T, 4>::Tensor input_backward,
                   typename TTypes<T, 4>::ConstTensor kernel,
                   typename TTypes<T, 4>::ConstTensor output_backward,
-                  int row_stride, int col_stride) {
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation) {
     // Need to swap row/col when calling Eigen.
     input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
         kernel, output_backward, input_backward.dimension(2),
-        input_backward.dimension(1), col_stride, row_stride);
+        input_backward.dimension(1), col_stride, row_stride, col_dilation,
+        row_dilation);
   }
 };
 
@@ -105,11 +111,13 @@ struct SpatialConvolutionBackwardFilter {
   void operator()(const Device& d,
                   typename TTypes<T, 4>::Tensor kernel_backward,
                   typename TTypes<T, 4>::ConstTensor input,
                   typename TTypes<T, 4>::ConstTensor output_backward,
-                  int row_stride, int col_stride) {
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation) {
     // Need to swap row/col when calling Eigen.
     kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
         input, output_backward, kernel_backward.dimension(1),
-        kernel_backward.dimension(0), col_stride, row_stride);
+        kernel_backward.dimension(0), col_stride, row_stride, col_dilation,
+        row_dilation);
   }
 };
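Note on argument order: as the in-code comments above say, row/col are swapped when calling into Eigen, and the new dilation arguments follow the same swap (col_dilation is passed before row_dilation). Per spatial axis, dilation stretches the kernel by the standard effective-extent formula; a minimal sketch of that arithmetic (illustrative only, not code from this patch):

  # Sketch of the dilation arithmetic delegated to Eigen (standard formula,
  # not part of this patch): a k-tap filter with dilation rate d reads k
  # samples spaced d apart, covering k + (k - 1) * (d - 1) input cells.
  def effective_kernel_size(k, d):
    return k + (k - 1) * (d - 1)

  assert effective_kernel_size(2, 2) == 3  # the 2x2-filter, dilation-2 tests below
  assert effective_kernel_size(3, 2) == 5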
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 512bcc6c01b..b8a5ae6a08e 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -101,7 +101,8 @@ struct LaunchConv2DBackpropFilterOp {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
         d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
-        out_backprop.tensor<T, 4>(), row_stride, col_stride);
+        out_backprop.tensor<T, 4>(), row_stride, col_stride,
+        /*row_dilation=*/1, /*col_dilation=*/1);
   }
 };
 
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 0356ff4c0f4..b87c7899c00 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -106,7 +106,8 @@ struct LaunchConv2DBackpropInputOp {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
         d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
-        out_backprop.tensor<T, 4>(), row_stride, col_stride);
+        out_backprop.tensor<T, 4>(), row_stride, col_stride,
+        /*row_dilation=*/1, /*col_dilation=*/1);
   }
 };
 
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index dbddaf3dc64..2b81e14f95b 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -60,8 +60,8 @@ template <typename Device, typename T>
 struct LaunchGeneric {
   void operator()(OpKernelContext* ctx, const Tensor& input,
                   const Tensor& filter, int row_stride, int col_stride,
-                  const Padding& padding, Tensor* output,
-                  TensorFormat data_format) {
+                  int row_dilation, int col_dilation, const Padding& padding,
+                  Tensor* output, TensorFormat data_format) {
     CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
                                          "supports NHWC tensor format for now.";
     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
@@ -86,7 +86,8 @@ struct LaunchGeneric {
           filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
           dim_pair);
     } else if (filter.dim_size(0) == input.dim_size(1) &&
-               filter.dim_size(1) == input.dim_size(2) && padding == VALID) {
+               filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
+               col_dilation == 1 && padding == VALID) {
       // If the input data and filter have the same height/width,
       // the 2D convolution is reduced to matrix multiplication.
       const int k =  // Length of reduction dimension.
@@ -103,7 +104,7 @@ struct LaunchGeneric {
       functor::SpatialConvolution<Device, T>()(
           ctx->eigen_device<Device>(), output->tensor<T, 4>(),
           input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
-          BrainPadding2EigenPadding(padding));
+          row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
     }
   }
 };
@@ -122,15 +123,9 @@ struct LaunchConv2DOp {
                       "NHWC tensor format for now."));
       return;
     }
-    // TODO(yangzihao): Add the CPU implementation of dilated conv 2D.
-    if (row_dilation > 1 || col_dilation > 1) {
-      ctx->SetStatus(
-          errors::Unimplemented("Generic conv implementation only supports "
-                                "dilated rate of 1 for now."));
-      return;
-    }
     LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
-                                  padding, output, data_format);
+                                  row_dilation, col_dilation, padding, output,
+                                  data_format);
   }
 };
 
@@ -792,7 +787,8 @@ namespace functor {
       const GPUDevice& d, typename TTypes<T, 4>::Tensor output,          \
       typename TTypes<T, 4>::ConstTensor input,                          \
      typename TTypes<T, 4>::ConstTensor filter, int row_stride,         \
-      int col_stride, const Eigen::PaddingType& padding);                \
+      int col_stride, int row_dilation, int col_dilation,                \
+      const Eigen::PaddingType& padding);                                \
   extern template struct SpatialConvolution<GPUDevice, T>;               \
   template <>                                                            \
   void MatMulConvFunctor<GPUDevice, T>::operator()(                      \
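With the Unimplemented guard removed, LaunchConv2DOp on CPU now forwards dilations straight through LaunchGeneric to Eigen::SpatialConvolution, and the matmul fast path is skipped unless both dilation rates are 1. The backprop launchers still pass /*row_dilation=*/1 and /*col_dilation=*/1, so only the forward pass gains dilation support in this patch. For intuition, here is a minimal NumPy model of the single-channel, stride-1 VALID dilated convolution being enabled (an illustrative sketch, not the Eigen code path the kernels actually run):

  # Illustrative NumPy reference for a dilated VALID convolution, single
  # channel, stride 1 (a sketch; the kernels above delegate to Eigen).
  import numpy as np

  def dilated_conv2d_valid(x, w, row_dilation=1, col_dilation=1):
    kh, kw = w.shape
    eh = kh + (kh - 1) * (row_dilation - 1)  # effective kernel height
    ew = kw + (kw - 1) * (col_dilation - 1)  # effective kernel width
    out = np.empty((x.shape[0] - eh + 1, x.shape[1] - ew + 1))
    for i in range(out.shape[0]):
      for j in range(out.shape[1]):
        # Sample the input at dilated offsets and correlate with the filter.
        patch = x[i:i + eh:row_dilation, j:j + ew:col_dilation]
        out[i, j] = np.sum(patch * w)
    return out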
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index edfb20d6a2b..27857989167 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -302,25 +302,20 @@ class Conv2DTest(test.TestCase):
                                padding, dilations):
     expected_results = []
     computed_results = []
-    default_dilations = (dilations[0] == 1 and dilations[1] == 1)
     for data_format, use_gpu in GetTestConfigs():
-      # If any dilation rate is larger than 1, only do test on the GPU
-      # because we currently do not have a CPU implementation for arbitrary
-      # dilation rates.
-      if default_dilations or use_gpu:
-        expected, computed = self._ComputeReferenceDilatedConv(
-            tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
-            data_format, use_gpu)
-        expected_results.append(expected)
-        computed_results.append(computed)
-        tolerance = 1e-2 if use_gpu else 1e-5
-        expected_values = self.evaluate(expected_results)
-        computed_values = self.evaluate(computed_results)
-        for e_value, c_value in zip(expected_values, computed_values):
-          print("expected = ", e_value)
-          print("actual = ", c_value)
-          self.assertAllClose(
-              e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
+      expected, computed = self._ComputeReferenceDilatedConv(
+          tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
+          data_format, use_gpu)
+      expected_results.append(expected)
+      computed_results.append(computed)
+      tolerance = 1e-2 if use_gpu else 1e-5
+      expected_values = self.evaluate(expected_results)
+      computed_values = self.evaluate(computed_results)
+      for e_value, c_value in zip(expected_values, computed_values):
+        print("expected = ", e_value)
+        print("actual = ", c_value)
+        self.assertAllClose(
+            e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
 
   def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
                     expected):
@@ -365,13 +360,12 @@ class Conv2DTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Filter2x1Dilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 4, 4, 1],
-          filter_in_sizes=[2, 2, 1, 1],
-          strides=[1, 1],
-          dilations=[2, 1],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 4, 4, 1],
+        filter_in_sizes=[2, 2, 1, 1],
+        strides=[1, 1],
+        dilations=[2, 1],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2DEmpty(self):
@@ -385,13 +379,12 @@ class Conv2DTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testConv2DEmptyDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[0, 2, 3, 3],
-          filter_in_sizes=[1, 1, 3, 3],
-          strides=[1, 1],
-          dilations=[2, 1],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[0, 2, 3, 3],
+        filter_in_sizes=[1, 1, 3, 3],
+        strides=[1, 1],
+        dilations=[2, 1],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Filter(self):
@@ -406,13 +399,12 @@ class Conv2DTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2FilterDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 2, 3, 3],
-          filter_in_sizes=[2, 2, 3, 3],
-          strides=[1, 1],
-          dilations=[1, 2],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 2, 3, 3],
+        filter_in_sizes=[2, 2, 3, 3],
+        strides=[1, 1],
+        dilations=[1, 2],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D1x2Filter(self):
@@ -430,13 +422,12 @@ class Conv2DTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D1x2FilterDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 2, 3, 3],
-          filter_in_sizes=[1, 2, 3, 3],
-          strides=[1, 1],
-          dilations=[2, 1],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 2, 3, 3],
+        filter_in_sizes=[1, 2, 3, 3],
+        strides=[1, 1],
+        dilations=[2, 1],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2FilterStride2(self):
@@ -512,13 +503,12 @@ class Conv2DTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testConv2DKernelSizeMatchesInputSizeDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 3, 3, 1],
-          filter_in_sizes=[2, 2, 1, 2],
-          strides=[1, 1],
-          dilations=[2, 2],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 3, 3, 1],
+        filter_in_sizes=[2, 2, 1, 2],
+        strides=[1, 1],
+        dilations=[2, 2],
+        padding="VALID")
 
   # TODO(yzhwang): this currently fails.
   # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
@@ -1538,21 +1528,6 @@ class Conv2DTest(test.TestCase):
           use_gpu=False)
       self.evaluate(conv)
 
-  def testCPUConv2DDilatedUnimplemented(self):
-    with self.test_session(use_gpu=False):
-      with self.assertRaisesRegexp(errors_impl.UnimplementedError,
-                                   "dilated rate of 1 for now"):
-        conv = self._SetupValuesForDevice(
-            tensor_in_sizes=[1, 4, 4, 1],
-            filter_in_sizes=[2, 2, 1, 1],
-            dilations=[2, 1],
-            strides=[1, 1],
-            padding="VALID",
-            data_format="NHWC",
-            dtype=dtypes.float32,
-            use_gpu=False)
-        self.evaluate(conv)
-
 
 class DepthwiseConv2DTest(test.TestCase):
 
@@ -1887,7 +1862,7 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding,
 def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding):
 
   def Test(self):
-    if test.is_gpu_available(cuda_only=True) and stride == 1:
+    if stride == 1:
       tf_logging.info("Testing InceptionFwd with dilations %s",
                       (input_size, filter_size, stride, padding))
       self._VerifyDilatedConvValues(
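Net effect of the test changes: the GPU-only guards around the dilation cases are dropped because a dilated tf.nn.conv2d now also runs on CPU. A hypothetical smoke check, mirroring the setup of the deleted testCPUConv2DDilatedUnimplemented (assumes a TensorFlow build that includes this change):

  # Hypothetical smoke test: with this patch, a dilated conv on /cpu:0
  # computes a result instead of raising UnimplementedError.
  import tensorflow as tf

  x = tf.ones([1, 4, 4, 1])
  w = tf.ones([2, 2, 1, 1])
  with tf.device("/cpu:0"):
    y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="VALID",
                     dilations=[1, 2, 1, 1])  # row dilation 2, col dilation 1
  # The effective kernel extent is 3x2, so the VALID output shape is [1, 2, 3, 1].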