Activates Eigen path for CPU implementation of atrous/dilated convolution (only forward path).
PiperOrigin-RevId: 186071285
This commit is contained in:
parent
0c14cf398c
commit
a189502cc3
@ -54,10 +54,12 @@ struct InflatePadAndShuffle {
|
||||
template <typename Device, typename Input, typename Filter, typename Output>
|
||||
void SpatialConvolutionFunc(const Device& d, Output output, Input input,
|
||||
Filter filter, int row_stride, int col_stride,
|
||||
int row_dilation, int col_dilation,
|
||||
const Eigen::PaddingType& padding) {
|
||||
// Need to swap row/col when calling Eigen.
|
||||
output.device(d) =
|
||||
Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding);
|
||||
Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
|
||||
col_dilation, row_dilation);
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
@ -65,9 +67,10 @@ struct SpatialConvolution {
|
||||
void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
|
||||
typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 4>::ConstTensor filter, int row_stride,
|
||||
int col_stride, const Eigen::PaddingType& padding) {
|
||||
int col_stride, int row_dilation, int col_dilation,
|
||||
const Eigen::PaddingType& padding) {
|
||||
SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
|
||||
padding);
|
||||
row_dilation, col_dilation, padding);
|
||||
}
|
||||
};
|
||||
|
||||
@ -77,11 +80,12 @@ struct SpatialConvolution<Device, Eigen::half> {
|
||||
typename TTypes<Eigen::half, 4>::Tensor output,
|
||||
typename TTypes<Eigen::half, 4>::ConstTensor input,
|
||||
typename TTypes<Eigen::half, 4>::ConstTensor filter,
|
||||
int row_stride, int col_stride,
|
||||
const Eigen::PaddingType& padding) {
|
||||
int row_stride, int col_stride, int row_dilation,
|
||||
int col_dilation, const Eigen::PaddingType& padding) {
|
||||
output.device(d) =
|
||||
Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(),
|
||||
col_stride, row_stride, padding)
|
||||
col_stride, row_stride, padding, col_dilation,
|
||||
row_dilation)
|
||||
.cast<Eigen::half>();
|
||||
}
|
||||
};
|
||||
@ -91,11 +95,13 @@ struct SpatialConvolutionBackwardInput {
|
||||
void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
|
||||
typename TTypes<T, 4>::ConstTensor kernel,
|
||||
typename TTypes<T, 4>::ConstTensor output_backward,
|
||||
int row_stride, int col_stride) {
|
||||
int row_stride, int col_stride, int row_dilation,
|
||||
int col_dilation) {
|
||||
// Need to swap row/col when calling Eigen.
|
||||
input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
|
||||
kernel, output_backward, input_backward.dimension(2),
|
||||
input_backward.dimension(1), col_stride, row_stride);
|
||||
input_backward.dimension(1), col_stride, row_stride, col_dilation,
|
||||
row_dilation);
|
||||
}
|
||||
};
|
||||
|
||||
@ -105,11 +111,13 @@ struct SpatialConvolutionBackwardFilter {
|
||||
typename TTypes<T, 4>::Tensor kernel_backward,
|
||||
typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 4>::ConstTensor output_backward,
|
||||
int row_stride, int col_stride) {
|
||||
int row_stride, int col_stride, int row_dilation,
|
||||
int col_dilation) {
|
||||
// Need to swap row/col when calling Eigen.
|
||||
kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
|
||||
input, output_backward, kernel_backward.dimension(1),
|
||||
kernel_backward.dimension(0), col_stride, row_stride);
|
||||
kernel_backward.dimension(0), col_stride, row_stride, col_dilation,
|
||||
row_dilation);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -101,7 +101,8 @@ struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
|
||||
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
|
||||
functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
|
||||
d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
|
||||
out_backprop.tensor<T, 4>(), row_stride, col_stride);
|
||||
out_backprop.tensor<T, 4>(), row_stride, col_stride,
|
||||
/*row_dilation=*/1, /*col_dilation=*/1);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -106,7 +106,8 @@ struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
|
||||
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
|
||||
functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
|
||||
d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
|
||||
out_backprop.tensor<T, 4>(), row_stride, col_stride);
|
||||
out_backprop.tensor<T, 4>(), row_stride, col_stride,
|
||||
/*row_dilation=*/1, /*col_dilation=*/1);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -60,8 +60,8 @@ template <typename Device, typename T>
|
||||
struct LaunchGeneric {
|
||||
void operator()(OpKernelContext* ctx, const Tensor& input,
|
||||
const Tensor& filter, int row_stride, int col_stride,
|
||||
const Padding& padding, Tensor* output,
|
||||
TensorFormat data_format) {
|
||||
int row_dilation, int col_dilation, const Padding& padding,
|
||||
Tensor* output, TensorFormat data_format) {
|
||||
CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
|
||||
"supports NHWC tensor format for now.";
|
||||
if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
|
||||
@ -86,7 +86,8 @@ struct LaunchGeneric {
|
||||
filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
|
||||
dim_pair);
|
||||
} else if (filter.dim_size(0) == input.dim_size(1) &&
|
||||
filter.dim_size(1) == input.dim_size(2) && padding == VALID) {
|
||||
filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
|
||||
col_dilation == 1 && padding == VALID) {
|
||||
// If the input data and filter have the same height/width,
|
||||
// the 2D convolution is reduced to matrix multiplication.
|
||||
const int k = // Length of reduction dimension.
|
||||
@ -103,7 +104,7 @@ struct LaunchGeneric {
|
||||
functor::SpatialConvolution<Device, T>()(
|
||||
ctx->eigen_device<Device>(), output->tensor<T, 4>(),
|
||||
input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
|
||||
BrainPadding2EigenPadding(padding));
|
||||
row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -122,15 +123,9 @@ struct LaunchConv2DOp<CPUDevice, T> {
|
||||
"NHWC tensor format for now."));
|
||||
return;
|
||||
}
|
||||
// TODO(yangzihao): Add the CPU implementation of dilated conv 2D.
|
||||
if (row_dilation > 1 || col_dilation > 1) {
|
||||
ctx->SetStatus(
|
||||
errors::Unimplemented("Generic conv implementation only supports "
|
||||
"dilated rate of 1 for now."));
|
||||
return;
|
||||
}
|
||||
LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
|
||||
padding, output, data_format);
|
||||
row_dilation, col_dilation, padding, output,
|
||||
data_format);
|
||||
}
|
||||
};
|
||||
|
||||
@ -792,7 +787,8 @@ namespace functor {
|
||||
const GPUDevice& d, typename TTypes<T, 4>::Tensor output, \
|
||||
typename TTypes<T, 4>::ConstTensor input, \
|
||||
typename TTypes<T, 4>::ConstTensor filter, int row_stride, \
|
||||
int col_stride, const Eigen::PaddingType& padding); \
|
||||
int col_stride, int row_dilation, int col_dilation, \
|
||||
const Eigen::PaddingType& padding); \
|
||||
extern template struct SpatialConvolution<GPUDevice, T>; \
|
||||
template <> \
|
||||
void MatMulConvFunctor<GPUDevice, T>::operator()( \
|
||||
|
@ -302,25 +302,20 @@ class Conv2DTest(test.TestCase):
|
||||
padding, dilations):
|
||||
expected_results = []
|
||||
computed_results = []
|
||||
default_dilations = (dilations[0] == 1 and dilations[1] == 1)
|
||||
for data_format, use_gpu in GetTestConfigs():
|
||||
# If any dilation rate is larger than 1, only do test on the GPU
|
||||
# because we currently do not have a CPU implementation for arbitrary
|
||||
# dilation rates.
|
||||
if default_dilations or use_gpu:
|
||||
expected, computed = self._ComputeReferenceDilatedConv(
|
||||
tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
|
||||
data_format, use_gpu)
|
||||
expected_results.append(expected)
|
||||
computed_results.append(computed)
|
||||
tolerance = 1e-2 if use_gpu else 1e-5
|
||||
expected_values = self.evaluate(expected_results)
|
||||
computed_values = self.evaluate(computed_results)
|
||||
for e_value, c_value in zip(expected_values, computed_values):
|
||||
print("expected = ", e_value)
|
||||
print("actual = ", c_value)
|
||||
self.assertAllClose(
|
||||
e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
|
||||
expected, computed = self._ComputeReferenceDilatedConv(
|
||||
tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
|
||||
data_format, use_gpu)
|
||||
expected_results.append(expected)
|
||||
computed_results.append(computed)
|
||||
tolerance = 1e-2 if use_gpu else 1e-5
|
||||
expected_values = self.evaluate(expected_results)
|
||||
computed_values = self.evaluate(computed_results)
|
||||
for e_value, c_value in zip(expected_values, computed_values):
|
||||
print("expected = ", e_value)
|
||||
print("actual = ", c_value)
|
||||
self.assertAllClose(
|
||||
e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
|
||||
|
||||
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
|
||||
expected):
|
||||
@ -365,13 +360,12 @@ class Conv2DTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2D2x2Filter2x1Dilation(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[1, 4, 4, 1],
|
||||
filter_in_sizes=[2, 2, 1, 1],
|
||||
strides=[1, 1],
|
||||
dilations=[2, 1],
|
||||
padding="VALID")
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[1, 4, 4, 1],
|
||||
filter_in_sizes=[2, 2, 1, 1],
|
||||
strides=[1, 1],
|
||||
dilations=[2, 1],
|
||||
padding="VALID")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2DEmpty(self):
|
||||
@ -385,13 +379,12 @@ class Conv2DTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2DEmptyDilation(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[0, 2, 3, 3],
|
||||
filter_in_sizes=[1, 1, 3, 3],
|
||||
strides=[1, 1],
|
||||
dilations=[2, 1],
|
||||
padding="VALID")
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[0, 2, 3, 3],
|
||||
filter_in_sizes=[1, 1, 3, 3],
|
||||
strides=[1, 1],
|
||||
dilations=[2, 1],
|
||||
padding="VALID")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2D2x2Filter(self):
|
||||
@ -406,13 +399,12 @@ class Conv2DTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2D2x2FilterDilation(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[1, 2, 3, 3],
|
||||
filter_in_sizes=[2, 2, 3, 3],
|
||||
strides=[1, 1],
|
||||
dilations=[1, 2],
|
||||
padding="VALID")
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[1, 2, 3, 3],
|
||||
filter_in_sizes=[2, 2, 3, 3],
|
||||
strides=[1, 1],
|
||||
dilations=[1, 2],
|
||||
padding="VALID")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2D1x2Filter(self):
|
||||
@ -430,13 +422,12 @@ class Conv2DTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2D1x2FilterDilation(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[1, 2, 3, 3],
|
||||
filter_in_sizes=[1, 2, 3, 3],
|
||||
strides=[1, 1],
|
||||
dilations=[2, 1],
|
||||
padding="VALID")
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[1, 2, 3, 3],
|
||||
filter_in_sizes=[1, 2, 3, 3],
|
||||
strides=[1, 1],
|
||||
dilations=[2, 1],
|
||||
padding="VALID")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2D2x2FilterStride2(self):
|
||||
@ -512,13 +503,12 @@ class Conv2DTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2DKernelSizeMatchesInputSizeDilation(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[1, 3, 3, 1],
|
||||
filter_in_sizes=[2, 2, 1, 2],
|
||||
strides=[1, 1],
|
||||
dilations=[2, 2],
|
||||
padding="VALID")
|
||||
self._VerifyDilatedConvValues(
|
||||
tensor_in_sizes=[1, 3, 3, 1],
|
||||
filter_in_sizes=[2, 2, 1, 2],
|
||||
strides=[1, 1],
|
||||
dilations=[2, 2],
|
||||
padding="VALID")
|
||||
|
||||
# TODO(yzhwang): this currently fails.
|
||||
# self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
|
||||
@ -1538,21 +1528,6 @@ class Conv2DTest(test.TestCase):
|
||||
use_gpu=False)
|
||||
self.evaluate(conv)
|
||||
|
||||
def testCPUConv2DDilatedUnimplemented(self):
|
||||
with self.test_session(use_gpu=False):
|
||||
with self.assertRaisesRegexp(errors_impl.UnimplementedError,
|
||||
"dilated rate of 1 for now"):
|
||||
conv = self._SetupValuesForDevice(
|
||||
tensor_in_sizes=[1, 4, 4, 1],
|
||||
filter_in_sizes=[2, 2, 1, 1],
|
||||
dilations=[2, 1],
|
||||
strides=[1, 1],
|
||||
padding="VALID",
|
||||
data_format="NHWC",
|
||||
dtype=dtypes.float32,
|
||||
use_gpu=False)
|
||||
self.evaluate(conv)
|
||||
|
||||
|
||||
class DepthwiseConv2DTest(test.TestCase):
|
||||
|
||||
@ -1887,7 +1862,7 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding,
|
||||
def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding):
|
||||
|
||||
def Test(self):
|
||||
if test.is_gpu_available(cuda_only=True) and stride == 1:
|
||||
if stride == 1:
|
||||
tf_logging.info("Testing InceptionFwd with dilations %s",
|
||||
(input_size, filter_size, stride, padding))
|
||||
self._VerifyDilatedConvValues(
|
||||
|
Loading…
x
Reference in New Issue
Block a user