Activates Eigen path for CPU implementation of atrous/dilated convolution (only forward path).
PiperOrigin-RevId: 186071285
This commit is contained in:
parent
0c14cf398c
commit
a189502cc3
@ -54,10 +54,12 @@ struct InflatePadAndShuffle {
|
|||||||
template <typename Device, typename Input, typename Filter, typename Output>
|
template <typename Device, typename Input, typename Filter, typename Output>
|
||||||
void SpatialConvolutionFunc(const Device& d, Output output, Input input,
|
void SpatialConvolutionFunc(const Device& d, Output output, Input input,
|
||||||
Filter filter, int row_stride, int col_stride,
|
Filter filter, int row_stride, int col_stride,
|
||||||
|
int row_dilation, int col_dilation,
|
||||||
const Eigen::PaddingType& padding) {
|
const Eigen::PaddingType& padding) {
|
||||||
// Need to swap row/col when calling Eigen.
|
// Need to swap row/col when calling Eigen.
|
||||||
output.device(d) =
|
output.device(d) =
|
||||||
Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding);
|
Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
|
||||||
|
col_dilation, row_dilation);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
@ -65,9 +67,10 @@ struct SpatialConvolution {
|
|||||||
void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
|
void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
|
||||||
typename TTypes<T, 4>::ConstTensor input,
|
typename TTypes<T, 4>::ConstTensor input,
|
||||||
typename TTypes<T, 4>::ConstTensor filter, int row_stride,
|
typename TTypes<T, 4>::ConstTensor filter, int row_stride,
|
||||||
int col_stride, const Eigen::PaddingType& padding) {
|
int col_stride, int row_dilation, int col_dilation,
|
||||||
|
const Eigen::PaddingType& padding) {
|
||||||
SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
|
SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
|
||||||
padding);
|
row_dilation, col_dilation, padding);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -77,11 +80,12 @@ struct SpatialConvolution<Device, Eigen::half> {
|
|||||||
typename TTypes<Eigen::half, 4>::Tensor output,
|
typename TTypes<Eigen::half, 4>::Tensor output,
|
||||||
typename TTypes<Eigen::half, 4>::ConstTensor input,
|
typename TTypes<Eigen::half, 4>::ConstTensor input,
|
||||||
typename TTypes<Eigen::half, 4>::ConstTensor filter,
|
typename TTypes<Eigen::half, 4>::ConstTensor filter,
|
||||||
int row_stride, int col_stride,
|
int row_stride, int col_stride, int row_dilation,
|
||||||
const Eigen::PaddingType& padding) {
|
int col_dilation, const Eigen::PaddingType& padding) {
|
||||||
output.device(d) =
|
output.device(d) =
|
||||||
Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(),
|
Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(),
|
||||||
col_stride, row_stride, padding)
|
col_stride, row_stride, padding, col_dilation,
|
||||||
|
row_dilation)
|
||||||
.cast<Eigen::half>();
|
.cast<Eigen::half>();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -91,11 +95,13 @@ struct SpatialConvolutionBackwardInput {
|
|||||||
void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
|
void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
|
||||||
typename TTypes<T, 4>::ConstTensor kernel,
|
typename TTypes<T, 4>::ConstTensor kernel,
|
||||||
typename TTypes<T, 4>::ConstTensor output_backward,
|
typename TTypes<T, 4>::ConstTensor output_backward,
|
||||||
int row_stride, int col_stride) {
|
int row_stride, int col_stride, int row_dilation,
|
||||||
|
int col_dilation) {
|
||||||
// Need to swap row/col when calling Eigen.
|
// Need to swap row/col when calling Eigen.
|
||||||
input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
|
input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
|
||||||
kernel, output_backward, input_backward.dimension(2),
|
kernel, output_backward, input_backward.dimension(2),
|
||||||
input_backward.dimension(1), col_stride, row_stride);
|
input_backward.dimension(1), col_stride, row_stride, col_dilation,
|
||||||
|
row_dilation);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -105,11 +111,13 @@ struct SpatialConvolutionBackwardFilter {
|
|||||||
typename TTypes<T, 4>::Tensor kernel_backward,
|
typename TTypes<T, 4>::Tensor kernel_backward,
|
||||||
typename TTypes<T, 4>::ConstTensor input,
|
typename TTypes<T, 4>::ConstTensor input,
|
||||||
typename TTypes<T, 4>::ConstTensor output_backward,
|
typename TTypes<T, 4>::ConstTensor output_backward,
|
||||||
int row_stride, int col_stride) {
|
int row_stride, int col_stride, int row_dilation,
|
||||||
|
int col_dilation) {
|
||||||
// Need to swap row/col when calling Eigen.
|
// Need to swap row/col when calling Eigen.
|
||||||
kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
|
kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
|
||||||
input, output_backward, kernel_backward.dimension(1),
|
input, output_backward, kernel_backward.dimension(1),
|
||||||
kernel_backward.dimension(0), col_stride, row_stride);
|
kernel_backward.dimension(0), col_stride, row_stride, col_dilation,
|
||||||
|
row_dilation);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -101,7 +101,8 @@ struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
|
|||||||
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
|
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
|
||||||
functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
|
functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
|
||||||
d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
|
d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
|
||||||
out_backprop.tensor<T, 4>(), row_stride, col_stride);
|
out_backprop.tensor<T, 4>(), row_stride, col_stride,
|
||||||
|
/*row_dilation=*/1, /*col_dilation=*/1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -106,7 +106,8 @@ struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
|
|||||||
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
|
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
|
||||||
functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
|
functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
|
||||||
d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
|
d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
|
||||||
out_backprop.tensor<T, 4>(), row_stride, col_stride);
|
out_backprop.tensor<T, 4>(), row_stride, col_stride,
|
||||||
|
/*row_dilation=*/1, /*col_dilation=*/1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -60,8 +60,8 @@ template <typename Device, typename T>
|
|||||||
struct LaunchGeneric {
|
struct LaunchGeneric {
|
||||||
void operator()(OpKernelContext* ctx, const Tensor& input,
|
void operator()(OpKernelContext* ctx, const Tensor& input,
|
||||||
const Tensor& filter, int row_stride, int col_stride,
|
const Tensor& filter, int row_stride, int col_stride,
|
||||||
const Padding& padding, Tensor* output,
|
int row_dilation, int col_dilation, const Padding& padding,
|
||||||
TensorFormat data_format) {
|
Tensor* output, TensorFormat data_format) {
|
||||||
CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
|
CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
|
||||||
"supports NHWC tensor format for now.";
|
"supports NHWC tensor format for now.";
|
||||||
if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
|
if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
|
||||||
@ -86,7 +86,8 @@ struct LaunchGeneric {
|
|||||||
filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
|
filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
|
||||||
dim_pair);
|
dim_pair);
|
||||||
} else if (filter.dim_size(0) == input.dim_size(1) &&
|
} else if (filter.dim_size(0) == input.dim_size(1) &&
|
||||||
filter.dim_size(1) == input.dim_size(2) && padding == VALID) {
|
filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
|
||||||
|
col_dilation == 1 && padding == VALID) {
|
||||||
// If the input data and filter have the same height/width,
|
// If the input data and filter have the same height/width,
|
||||||
// the 2D convolution is reduced to matrix multiplication.
|
// the 2D convolution is reduced to matrix multiplication.
|
||||||
const int k = // Length of reduction dimension.
|
const int k = // Length of reduction dimension.
|
||||||
@ -103,7 +104,7 @@ struct LaunchGeneric {
|
|||||||
functor::SpatialConvolution<Device, T>()(
|
functor::SpatialConvolution<Device, T>()(
|
||||||
ctx->eigen_device<Device>(), output->tensor<T, 4>(),
|
ctx->eigen_device<Device>(), output->tensor<T, 4>(),
|
||||||
input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
|
input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
|
||||||
BrainPadding2EigenPadding(padding));
|
row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -122,15 +123,9 @@ struct LaunchConv2DOp<CPUDevice, T> {
|
|||||||
"NHWC tensor format for now."));
|
"NHWC tensor format for now."));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// TODO(yangzihao): Add the CPU implementation of dilated conv 2D.
|
|
||||||
if (row_dilation > 1 || col_dilation > 1) {
|
|
||||||
ctx->SetStatus(
|
|
||||||
errors::Unimplemented("Generic conv implementation only supports "
|
|
||||||
"dilated rate of 1 for now."));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
|
LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
|
||||||
padding, output, data_format);
|
row_dilation, col_dilation, padding, output,
|
||||||
|
data_format);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -792,7 +787,8 @@ namespace functor {
|
|||||||
const GPUDevice& d, typename TTypes<T, 4>::Tensor output, \
|
const GPUDevice& d, typename TTypes<T, 4>::Tensor output, \
|
||||||
typename TTypes<T, 4>::ConstTensor input, \
|
typename TTypes<T, 4>::ConstTensor input, \
|
||||||
typename TTypes<T, 4>::ConstTensor filter, int row_stride, \
|
typename TTypes<T, 4>::ConstTensor filter, int row_stride, \
|
||||||
int col_stride, const Eigen::PaddingType& padding); \
|
int col_stride, int row_dilation, int col_dilation, \
|
||||||
|
const Eigen::PaddingType& padding); \
|
||||||
extern template struct SpatialConvolution<GPUDevice, T>; \
|
extern template struct SpatialConvolution<GPUDevice, T>; \
|
||||||
template <> \
|
template <> \
|
||||||
void MatMulConvFunctor<GPUDevice, T>::operator()( \
|
void MatMulConvFunctor<GPUDevice, T>::operator()( \
|
||||||
|
|||||||
@ -302,25 +302,20 @@ class Conv2DTest(test.TestCase):
|
|||||||
padding, dilations):
|
padding, dilations):
|
||||||
expected_results = []
|
expected_results = []
|
||||||
computed_results = []
|
computed_results = []
|
||||||
default_dilations = (dilations[0] == 1 and dilations[1] == 1)
|
|
||||||
for data_format, use_gpu in GetTestConfigs():
|
for data_format, use_gpu in GetTestConfigs():
|
||||||
# If any dilation rate is larger than 1, only do test on the GPU
|
expected, computed = self._ComputeReferenceDilatedConv(
|
||||||
# because we currently do not have a CPU implementation for arbitrary
|
tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
|
||||||
# dilation rates.
|
data_format, use_gpu)
|
||||||
if default_dilations or use_gpu:
|
expected_results.append(expected)
|
||||||
expected, computed = self._ComputeReferenceDilatedConv(
|
computed_results.append(computed)
|
||||||
tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
|
tolerance = 1e-2 if use_gpu else 1e-5
|
||||||
data_format, use_gpu)
|
expected_values = self.evaluate(expected_results)
|
||||||
expected_results.append(expected)
|
computed_values = self.evaluate(computed_results)
|
||||||
computed_results.append(computed)
|
for e_value, c_value in zip(expected_values, computed_values):
|
||||||
tolerance = 1e-2 if use_gpu else 1e-5
|
print("expected = ", e_value)
|
||||||
expected_values = self.evaluate(expected_results)
|
print("actual = ", c_value)
|
||||||
computed_values = self.evaluate(computed_results)
|
self.assertAllClose(
|
||||||
for e_value, c_value in zip(expected_values, computed_values):
|
e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
|
||||||
print("expected = ", e_value)
|
|
||||||
print("actual = ", c_value)
|
|
||||||
self.assertAllClose(
|
|
||||||
e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
|
|
||||||
|
|
||||||
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
|
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
|
||||||
expected):
|
expected):
|
||||||
@ -365,13 +360,12 @@ class Conv2DTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2D2x2Filter2x1Dilation(self):
|
def testConv2D2x2Filter2x1Dilation(self):
|
||||||
if test.is_gpu_available(cuda_only=True):
|
self._VerifyDilatedConvValues(
|
||||||
self._VerifyDilatedConvValues(
|
tensor_in_sizes=[1, 4, 4, 1],
|
||||||
tensor_in_sizes=[1, 4, 4, 1],
|
filter_in_sizes=[2, 2, 1, 1],
|
||||||
filter_in_sizes=[2, 2, 1, 1],
|
strides=[1, 1],
|
||||||
strides=[1, 1],
|
dilations=[2, 1],
|
||||||
dilations=[2, 1],
|
padding="VALID")
|
||||||
padding="VALID")
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2DEmpty(self):
|
def testConv2DEmpty(self):
|
||||||
@ -385,13 +379,12 @@ class Conv2DTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2DEmptyDilation(self):
|
def testConv2DEmptyDilation(self):
|
||||||
if test.is_gpu_available(cuda_only=True):
|
self._VerifyDilatedConvValues(
|
||||||
self._VerifyDilatedConvValues(
|
tensor_in_sizes=[0, 2, 3, 3],
|
||||||
tensor_in_sizes=[0, 2, 3, 3],
|
filter_in_sizes=[1, 1, 3, 3],
|
||||||
filter_in_sizes=[1, 1, 3, 3],
|
strides=[1, 1],
|
||||||
strides=[1, 1],
|
dilations=[2, 1],
|
||||||
dilations=[2, 1],
|
padding="VALID")
|
||||||
padding="VALID")
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2D2x2Filter(self):
|
def testConv2D2x2Filter(self):
|
||||||
@ -406,13 +399,12 @@ class Conv2DTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2D2x2FilterDilation(self):
|
def testConv2D2x2FilterDilation(self):
|
||||||
if test.is_gpu_available(cuda_only=True):
|
self._VerifyDilatedConvValues(
|
||||||
self._VerifyDilatedConvValues(
|
tensor_in_sizes=[1, 2, 3, 3],
|
||||||
tensor_in_sizes=[1, 2, 3, 3],
|
filter_in_sizes=[2, 2, 3, 3],
|
||||||
filter_in_sizes=[2, 2, 3, 3],
|
strides=[1, 1],
|
||||||
strides=[1, 1],
|
dilations=[1, 2],
|
||||||
dilations=[1, 2],
|
padding="VALID")
|
||||||
padding="VALID")
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2D1x2Filter(self):
|
def testConv2D1x2Filter(self):
|
||||||
@ -430,13 +422,12 @@ class Conv2DTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2D1x2FilterDilation(self):
|
def testConv2D1x2FilterDilation(self):
|
||||||
if test.is_gpu_available(cuda_only=True):
|
self._VerifyDilatedConvValues(
|
||||||
self._VerifyDilatedConvValues(
|
tensor_in_sizes=[1, 2, 3, 3],
|
||||||
tensor_in_sizes=[1, 2, 3, 3],
|
filter_in_sizes=[1, 2, 3, 3],
|
||||||
filter_in_sizes=[1, 2, 3, 3],
|
strides=[1, 1],
|
||||||
strides=[1, 1],
|
dilations=[2, 1],
|
||||||
dilations=[2, 1],
|
padding="VALID")
|
||||||
padding="VALID")
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2D2x2FilterStride2(self):
|
def testConv2D2x2FilterStride2(self):
|
||||||
@ -512,13 +503,12 @@ class Conv2DTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2DKernelSizeMatchesInputSizeDilation(self):
|
def testConv2DKernelSizeMatchesInputSizeDilation(self):
|
||||||
if test.is_gpu_available(cuda_only=True):
|
self._VerifyDilatedConvValues(
|
||||||
self._VerifyDilatedConvValues(
|
tensor_in_sizes=[1, 3, 3, 1],
|
||||||
tensor_in_sizes=[1, 3, 3, 1],
|
filter_in_sizes=[2, 2, 1, 2],
|
||||||
filter_in_sizes=[2, 2, 1, 2],
|
strides=[1, 1],
|
||||||
strides=[1, 1],
|
dilations=[2, 2],
|
||||||
dilations=[2, 2],
|
padding="VALID")
|
||||||
padding="VALID")
|
|
||||||
|
|
||||||
# TODO(yzhwang): this currently fails.
|
# TODO(yzhwang): this currently fails.
|
||||||
# self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
|
# self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
|
||||||
@ -1538,21 +1528,6 @@ class Conv2DTest(test.TestCase):
|
|||||||
use_gpu=False)
|
use_gpu=False)
|
||||||
self.evaluate(conv)
|
self.evaluate(conv)
|
||||||
|
|
||||||
def testCPUConv2DDilatedUnimplemented(self):
|
|
||||||
with self.test_session(use_gpu=False):
|
|
||||||
with self.assertRaisesRegexp(errors_impl.UnimplementedError,
|
|
||||||
"dilated rate of 1 for now"):
|
|
||||||
conv = self._SetupValuesForDevice(
|
|
||||||
tensor_in_sizes=[1, 4, 4, 1],
|
|
||||||
filter_in_sizes=[2, 2, 1, 1],
|
|
||||||
dilations=[2, 1],
|
|
||||||
strides=[1, 1],
|
|
||||||
padding="VALID",
|
|
||||||
data_format="NHWC",
|
|
||||||
dtype=dtypes.float32,
|
|
||||||
use_gpu=False)
|
|
||||||
self.evaluate(conv)
|
|
||||||
|
|
||||||
|
|
||||||
class DepthwiseConv2DTest(test.TestCase):
|
class DepthwiseConv2DTest(test.TestCase):
|
||||||
|
|
||||||
@ -1887,7 +1862,7 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding,
|
|||||||
def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding):
|
def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding):
|
||||||
|
|
||||||
def Test(self):
|
def Test(self):
|
||||||
if test.is_gpu_available(cuda_only=True) and stride == 1:
|
if stride == 1:
|
||||||
tf_logging.info("Testing InceptionFwd with dilations %s",
|
tf_logging.info("Testing InceptionFwd with dilations %s",
|
||||||
(input_size, filter_size, stride, padding))
|
(input_size, filter_size, stride, padding))
|
||||||
self._VerifyDilatedConvValues(
|
self._VerifyDilatedConvValues(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user