Activates the Eigen path for the CPU implementation of atrous/dilated convolution (forward path only).

PiperOrigin-RevId: 186071285
Commit: a189502cc3 (parent: 0c14cf398c)
Author: A. Unique TensorFlower, 2018-02-16 17:56:36 -08:00
Committed by: TensorFlower Gardener
5 changed files with 75 additions and 94 deletions
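
For context, a minimal sketch of the behavior this commit enables, assuming a TensorFlow 1.x build that includes it (shapes and values below are illustrative, not taken from the commit): tf.nn.conv2d with a dilation rate greater than 1 now runs on CPU instead of failing with an Unimplemented error.

import numpy as np
import tensorflow as tf

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)  # NHWC input
w = np.ones((2, 2, 1, 1), dtype=np.float32)              # HWIO filter

with tf.device("/cpu:0"):
  # Before this commit the CPU kernel rejected any dilation rate > 1.
  y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="VALID",
                   dilations=[1, 2, 2, 1])

with tf.Session() as sess:
  # A 2x2 filter at dilation 2 spans 3x3 input pixels, so VALID output
  # on a 4x4 input is 2x2.
  print(sess.run(y).shape)  # (1, 2, 2, 1)

Note that only the forward kernel gains dilation support here; the backward (gradient) kernels in the diffs below still pass a dilation of 1.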

tensorflow/core/kernels/conv_2d.h

@@ -54,10 +54,12 @@ struct InflatePadAndShuffle {
 template <typename Device, typename Input, typename Filter, typename Output>
 void SpatialConvolutionFunc(const Device& d, Output output, Input input,
                             Filter filter, int row_stride, int col_stride,
+                            int row_dilation, int col_dilation,
                             const Eigen::PaddingType& padding) {
   // Need to swap row/col when calling Eigen.
   output.device(d) =
-      Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding);
+      Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
+                                col_dilation, row_dilation);
 }
 
 template <typename Device, typename T>

@@ -65,9 +67,10 @@ struct SpatialConvolution {
   void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
                   typename TTypes<T, 4>::ConstTensor input,
                   typename TTypes<T, 4>::ConstTensor filter, int row_stride,
-                  int col_stride, const Eigen::PaddingType& padding) {
+                  int col_stride, int row_dilation, int col_dilation,
+                  const Eigen::PaddingType& padding) {
     SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
-                           padding);
+                           row_dilation, col_dilation, padding);
   }
 };

@@ -77,11 +80,12 @@ struct SpatialConvolution<Device, Eigen::half> {
                   typename TTypes<Eigen::half, 4>::Tensor output,
                   typename TTypes<Eigen::half, 4>::ConstTensor input,
                   typename TTypes<Eigen::half, 4>::ConstTensor filter,
-                  int row_stride, int col_stride,
-                  const Eigen::PaddingType& padding) {
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, const Eigen::PaddingType& padding) {
     output.device(d) =
         Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(),
-                                  col_stride, row_stride, padding)
+                                  col_stride, row_stride, padding, col_dilation,
+                                  row_dilation)
             .cast<Eigen::half>();
   }
 };

@@ -91,11 +95,13 @@ struct SpatialConvolutionBackwardInput {
   void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
                   typename TTypes<T, 4>::ConstTensor kernel,
                   typename TTypes<T, 4>::ConstTensor output_backward,
-                  int row_stride, int col_stride) {
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation) {
     // Need to swap row/col when calling Eigen.
     input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
         kernel, output_backward, input_backward.dimension(2),
-        input_backward.dimension(1), col_stride, row_stride);
+        input_backward.dimension(1), col_stride, row_stride, col_dilation,
+        row_dilation);
   }
 };

@@ -105,11 +111,13 @@ struct SpatialConvolutionBackwardFilter {
                   typename TTypes<T, 4>::Tensor kernel_backward,
                   typename TTypes<T, 4>::ConstTensor input,
                   typename TTypes<T, 4>::ConstTensor output_backward,
-                  int row_stride, int col_stride) {
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation) {
     // Need to swap row/col when calling Eigen.
     kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
         input, output_backward, kernel_backward.dimension(1),
-        kernel_backward.dimension(0), col_stride, row_stride);
+        kernel_backward.dimension(0), col_stride, row_stride, col_dilation,
+        row_dilation);
   }
 };
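
The functors above swap row and column arguments ("Need to swap row/col when calling Eigen") because these kernels use row-major tensors while Eigen's spatial convolution helpers are written for Eigen's default column-major layout, which reverses the dimension order; the new dilation arguments are swapped the same way. Semantically, a dilation rate d inserts d - 1 implicit zeros between filter taps. A small NumPy sketch of that equivalence (dilate_kernel is an illustrative helper, not part of this commit):

import numpy as np

def dilate_kernel(w, row_dilation, col_dilation):
  # Insert (rate - 1) zero rows/columns between taps: a 2x2 kernel at
  # dilation 2 acts like a 3x3 kernel whose middle row and column are zero.
  kh, kw, cin, cout = w.shape
  eh = kh + (kh - 1) * (row_dilation - 1)
  ew = kw + (kw - 1) * (col_dilation - 1)
  out = np.zeros((eh, ew, cin, cout), dtype=w.dtype)
  out[::row_dilation, ::col_dilation] = w
  return out

w = np.arange(4, dtype=np.float32).reshape(2, 2, 1, 1)
print(dilate_kernel(w, 2, 2)[:, :, 0, 0])
# [[0. 0. 1.]
#  [0. 0. 0.]
#  [2. 0. 3.]]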

tensorflow/core/kernels/conv_grad_filter_ops.cc

@@ -101,7 +101,8 @@ struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
         d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
-        out_backprop.tensor<T, 4>(), row_stride, col_stride);
+        out_backprop.tensor<T, 4>(), row_stride, col_stride,
+        /*row_dilation=*/1, /*col_dilation=*/1);
   }
 };

tensorflow/core/kernels/conv_grad_input_ops.cc

@@ -106,7 +106,8 @@ struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
         d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
-        out_backprop.tensor<T, 4>(), row_stride, col_stride);
+        out_backprop.tensor<T, 4>(), row_stride, col_stride,
+        /*row_dilation=*/1, /*col_dilation=*/1);
   }
 };

tensorflow/core/kernels/conv_ops.cc

@@ -60,8 +60,8 @@ template <typename Device, typename T>
 struct LaunchGeneric {
   void operator()(OpKernelContext* ctx, const Tensor& input,
                   const Tensor& filter, int row_stride, int col_stride,
-                  const Padding& padding, Tensor* output,
-                  TensorFormat data_format) {
+                  int row_dilation, int col_dilation, const Padding& padding,
+                  Tensor* output, TensorFormat data_format) {
     CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
                                          "supports NHWC tensor format for now.";
     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&

@@ -86,7 +86,8 @@ struct LaunchGeneric {
           filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
           dim_pair);
     } else if (filter.dim_size(0) == input.dim_size(1) &&
-               filter.dim_size(1) == input.dim_size(2) && padding == VALID) {
+               filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
+               col_dilation == 1 && padding == VALID) {
       // If the input data and filter have the same height/width,
       // the 2D convolution is reduced to matrix multiplication.
       const int k =  // Length of reduction dimension.

@@ -103,7 +104,7 @@ struct LaunchGeneric {
       functor::SpatialConvolution<Device, T>()(
           ctx->eigen_device<Device>(), output->tensor<T, 4>(),
           input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
-          BrainPadding2EigenPadding(padding));
+          row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
     }
   }
 };
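
The new row_dilation == 1 && col_dilation == 1 condition in the hunk above matters because the matmul shortcut assumes a filter that covers the input exactly, producing a 1x1 spatial output; under dilation the filter's effective extent grows, so matching undilated sizes no longer implies that. A quick sketch of the VALID-padding arithmetic (helper names are illustrative):

def effective_filter_size(k, dilation):
  # A k-tap filter at dilation d spans k + (k - 1) * (d - 1) input pixels.
  return k + (k - 1) * (dilation - 1)

def valid_output_size(input_size, k, stride, dilation):
  return (input_size - effective_filter_size(k, dilation)) // stride + 1

# Filter height == input height and dilation == 1: a single output row,
# so the convolution collapses to a matrix multiplication.
print(valid_output_size(input_size=3, k=3, stride=1, dilation=1))  # 1
# With dilation 2 the same 3-tap filter spans 5 pixels, so the shortcut's
# premise no longer holds for a 3-pixel input.
print(effective_filter_size(k=3, dilation=2))  # 5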
@@ -122,15 +123,9 @@ struct LaunchConv2DOp<CPUDevice, T> {
                                       "NHWC tensor format for now."));
       return;
     }
-    // TODO(yangzihao): Add the CPU implementation of dilated conv 2D.
-    if (row_dilation > 1 || col_dilation > 1) {
-      ctx->SetStatus(
-          errors::Unimplemented("Generic conv implementation only supports "
-                                "dilated rate of 1 for now."));
-      return;
-    }
     LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
-                                  padding, output, data_format);
+                                  row_dilation, col_dilation, padding, output,
+                                  data_format);
   }
 };

@@ -792,7 +787,8 @@ namespace functor {
       const GPUDevice& d, typename TTypes<T, 4>::Tensor output,          \
       typename TTypes<T, 4>::ConstTensor input,                          \
       typename TTypes<T, 4>::ConstTensor filter, int row_stride,         \
-      int col_stride, const Eigen::PaddingType& padding);                \
+      int col_stride, int row_dilation, int col_dilation,                \
+      const Eigen::PaddingType& padding);                                \
   extern template struct SpatialConvolution<GPUDevice, T>;               \
   template <>                                                            \
   void MatMulConvFunctor<GPUDevice, T>::operator()(                      \
tensorflow/python/kernel_tests/conv_ops_test.py

@@ -302,25 +302,20 @@ class Conv2DTest(test.TestCase):
                                padding, dilations):
     expected_results = []
     computed_results = []
-    default_dilations = (dilations[0] == 1 and dilations[1] == 1)
     for data_format, use_gpu in GetTestConfigs():
-      # If any dilation rate is larger than 1, only do test on the GPU
-      # because we currently do not have a CPU implementation for arbitrary
-      # dilation rates.
-      if default_dilations or use_gpu:
-        expected, computed = self._ComputeReferenceDilatedConv(
-            tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
-            data_format, use_gpu)
-        expected_results.append(expected)
-        computed_results.append(computed)
-        tolerance = 1e-2 if use_gpu else 1e-5
-        expected_values = self.evaluate(expected_results)
-        computed_values = self.evaluate(computed_results)
-        for e_value, c_value in zip(expected_values, computed_values):
-          print("expected = ", e_value)
-          print("actual = ", c_value)
-          self.assertAllClose(
-              e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
+      expected, computed = self._ComputeReferenceDilatedConv(
+          tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
+          data_format, use_gpu)
+      expected_results.append(expected)
+      computed_results.append(computed)
+      tolerance = 1e-2 if use_gpu else 1e-5
+      expected_values = self.evaluate(expected_results)
+      computed_values = self.evaluate(computed_results)
+      for e_value, c_value in zip(expected_values, computed_values):
+        print("expected = ", e_value)
+        print("actual = ", c_value)
+        self.assertAllClose(
+            e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
 
   def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
                     expected):

@@ -365,13 +360,12 @@
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Filter2x1Dilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 4, 4, 1],
-          filter_in_sizes=[2, 2, 1, 1],
-          strides=[1, 1],
-          dilations=[2, 1],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 4, 4, 1],
+        filter_in_sizes=[2, 2, 1, 1],
+        strides=[1, 1],
+        dilations=[2, 1],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2DEmpty(self):

@@ -385,13 +379,12 @@
   @test_util.run_in_graph_and_eager_modes()
   def testConv2DEmptyDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[0, 2, 3, 3],
-          filter_in_sizes=[1, 1, 3, 3],
-          strides=[1, 1],
-          dilations=[2, 1],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[0, 2, 3, 3],
+        filter_in_sizes=[1, 1, 3, 3],
+        strides=[1, 1],
+        dilations=[2, 1],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Filter(self):

@@ -406,13 +399,12 @@
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2FilterDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 2, 3, 3],
-          filter_in_sizes=[2, 2, 3, 3],
-          strides=[1, 1],
-          dilations=[1, 2],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 2, 3, 3],
+        filter_in_sizes=[2, 2, 3, 3],
+        strides=[1, 1],
+        dilations=[1, 2],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D1x2Filter(self):

@@ -430,13 +422,12 @@
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D1x2FilterDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 2, 3, 3],
-          filter_in_sizes=[1, 2, 3, 3],
-          strides=[1, 1],
-          dilations=[2, 1],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 2, 3, 3],
+        filter_in_sizes=[1, 2, 3, 3],
+        strides=[1, 1],
+        dilations=[2, 1],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2FilterStride2(self):

@@ -512,13 +503,12 @@
   @test_util.run_in_graph_and_eager_modes()
   def testConv2DKernelSizeMatchesInputSizeDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 3, 3, 1],
-          filter_in_sizes=[2, 2, 1, 2],
-          strides=[1, 1],
-          dilations=[2, 2],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 3, 3, 1],
+        filter_in_sizes=[2, 2, 1, 2],
+        strides=[1, 1],
+        dilations=[2, 2],
+        padding="VALID")
 
   # TODO(yzhwang): this currently fails.
   # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],

@@ -1538,21 +1528,6 @@
         use_gpu=False)
     self.evaluate(conv)
 
-  def testCPUConv2DDilatedUnimplemented(self):
-    with self.test_session(use_gpu=False):
-      with self.assertRaisesRegexp(errors_impl.UnimplementedError,
-                                   "dilated rate of 1 for now"):
-        conv = self._SetupValuesForDevice(
-            tensor_in_sizes=[1, 4, 4, 1],
-            filter_in_sizes=[2, 2, 1, 1],
-            dilations=[2, 1],
-            strides=[1, 1],
-            padding="VALID",
-            data_format="NHWC",
-            dtype=dtypes.float32,
-            use_gpu=False)
-        self.evaluate(conv)
 
 class DepthwiseConv2DTest(test.TestCase):

@@ -1887,7 +1862,7 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding,
 def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding):
 
   def Test(self):
-    if test.is_gpu_available(cuda_only=True) and stride == 1:
+    if stride == 1:
       tf_logging.info("Testing InceptionFwd with dilations %s",
                       (input_size, filter_size, stride, padding))
       self._VerifyDilatedConvValues(