Support explicit padding on CPU for tf.nn.conv2d.
PiperOrigin-RevId: 244421216
parent d276f27d75 · commit ee82131dbc
Changed paths: tensorflow/core/kernels, tensorflow/python
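For orientation, a minimal sketch of what this commit enables (hypothetical shapes; TF 2.x-style calls for brevity, assuming a build with explicit-padding support): `tf.nn.conv2d` can take a per-dimension padding list on CPU, and explicit padding is equivalent to zero-padding the input and then convolving with `"VALID"`:

```python
import tensorflow as tf

x = tf.random.normal([1, 5, 5, 3])   # NHWC input (hypothetical shapes)
f = tf.random.normal([2, 2, 3, 8])   # HWIO filter

# Explicit padding: [[batch], [top, bottom], [left, right], [channel]].
# Before this commit, this raised errors::Unimplemented on CPU.
y = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1],
                 padding=[[0, 0], [1, 2], [3, 4], [0, 0]])

# Reference computation: zero-pad manually, then convolve with VALID padding.
x_pad = tf.pad(x, [[0, 0], [1, 2], [3, 4], [0, 0]])
y_ref = tf.nn.conv2d(x_pad, f, strides=[1, 1, 1, 1], padding="VALID")
# Both have shape [1, 7, 11, 8] and are numerically equal.
```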
@@ -57,11 +57,16 @@ void SpatialConvolutionFunc(const Device& d, Output output, Input input,
                             Filter filter, int row_stride, int col_stride,
                             int row_dilation, int col_dilation,
                             const Eigen::PaddingType& padding,
-                            const OutputKernel& output_kernel) {
-  // Need to swap row/col when calling Eigen.
-  output.device(d) =
-      Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
-                                col_dilation, row_dilation, output_kernel);
+                            const OutputKernel& output_kernel,
+                            int padding_top = 0, int padding_bottom = 0,
+                            int padding_left = 0, int padding_right = 0) {
+  // Need to swap row/col, padding_top/padding_left, and
+  // padding_bottom/padding_right when calling Eigen. Eigen expects the tensor
+  // in NWHC format, but the tensor given is in NHWC.
+  output.device(d) = Eigen::SpatialConvolution(
+      input, filter, col_stride, row_stride, padding, col_dilation,
+      row_dilation, output_kernel, padding_left, padding_right, padding_top,
+      padding_bottom);
 }

 template <typename Device, typename T,

@@ -76,6 +81,18 @@ struct SpatialConvolution {
     SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
                            row_dilation, col_dilation, padding, output_kernel);
   }
+  void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
+                  typename TTypes<T, 4>::ConstTensor input,
+                  typename TTypes<T, 4>::ConstTensor filter, int row_stride,
+                  int col_stride, int row_dilation, int col_dilation,
+                  int padding_top, int padding_bottom, int padding_left,
+                  int padding_right,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    SpatialConvolutionFunc(
+        d, output, input, filter, row_stride, col_stride, row_dilation,
+        col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel,
+        padding_top, padding_bottom, padding_left, padding_right);
+  }
 };

 template <typename Device, typename OutputKernel>

@@ -93,6 +110,22 @@ struct SpatialConvolution<Device, Eigen::half, OutputKernel> {
                                       row_dilation, output_kernel)
             .template cast<Eigen::half>();
   }
+  void operator()(const Device& d,
+                  typename TTypes<Eigen::half, 4>::Tensor output,
+                  typename TTypes<Eigen::half, 4>::ConstTensor input,
+                  typename TTypes<Eigen::half, 4>::ConstTensor filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, int padding_top, int padding_bottom,
+                  int padding_left, int padding_right,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    output.device(d) =
+        Eigen::SpatialConvolution(
+            input.cast<float>(), filter.cast<float>(), col_stride, row_stride,
+            Eigen::PaddingType::PADDING_VALID, col_dilation, row_dilation,
+            output_kernel, padding_left, padding_right, padding_top,
+            padding_bottom)
+            .template cast<Eigen::half>();
+  }
 };

 template <typename Device, typename T>
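The overloads above implement explicit padding by forwarding the four pad amounts to Eigen together with `PADDING_VALID`. A NumPy sketch of the identity this relies on (single channel, stride 1, hypothetical helper names):

```python
import numpy as np

def conv2d_valid(x, f):
    """Naive single-channel VALID convolution (cross-correlation)."""
    H, W = x.shape
    kH, kW = f.shape
    out = np.zeros((H - kH + 1, W - kW + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kH, j:j + kW] * f)
    return out

def conv2d_explicit(x, f, pad_top, pad_bottom, pad_left, pad_right):
    """Explicit padding == zero-pad the input, then VALID convolution."""
    x_pad = np.pad(x, ((pad_top, pad_bottom), (pad_left, pad_right)))
    return conv2d_valid(x_pad, f)
```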
@@ -208,14 +208,9 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
                 errors::InvalidArgument(
                     "Row and column strides should be larger than 0."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-    OP_REQUIRES(
-        context, padding_ != Padding::EXPLICIT,
-        errors::Unimplemented("Current CPU implementation does not support "
-                              "EXPLICIT padding yet."));
-    std::vector<int64> explicit_paddings;
-    OP_REQUIRES_OK(context,
-                   context->GetAttr("explicit_paddings", &explicit_paddings));
-    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings,
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("explicit_paddings", &explicit_paddings_));
+    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                               /*num_dims=*/4, data_format_));
     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
     OP_REQUIRES(context, dilations_.size() == 4,

@@ -247,11 +242,12 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
                                 filter_sizes.vec<int32>(), &filter_shape));

     ConvBackpropDimensions dims;
-    OP_REQUIRES_OK(context,
-                   ConvBackpropComputeDimensions(
-                       "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2,
-                       input.shape(), filter_shape, out_backprop.shape(),
-                       strides_, padding_, data_format_, &dims));
+    OP_REQUIRES_OK(
+        context,
+        ConvBackpropComputeDimensionsV2(
+            "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
+            filter_shape, out_backprop.shape(), /*dilations=*/{1, 1, 1, 1},
+            strides_, padding_, explicit_paddings_, data_format_, &dims));

     Tensor* filter_backprop;
     OP_REQUIRES_OK(context,

@@ -264,6 +260,12 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {

     int64 pad_top, pad_bottom;
     int64 pad_left, pad_right;
+    if (padding_ == Padding::EXPLICIT) {
+      pad_top = explicit_paddings_[2];
+      pad_bottom = explicit_paddings_[3];
+      pad_left = explicit_paddings_[4];
+      pad_right = explicit_paddings_[5];
+    }
     OP_REQUIRES_OK(
         context,
         GetWindowedOutputSizeVerbose(

@@ -402,6 +404,7 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
   std::vector<int32> dilations_;
   std::vector<int32> strides_;
   Padding padding_;
+  std::vector<int64> explicit_paddings_;
   TensorFormat data_format_;

   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropFilterOp);
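With the `Unimplemented` check gone, the filter gradient can be exercised with explicit padding on CPU. A sketch using today's eager API (the test file below drives the same op through graph mode):

```python
import tensorflow as tf

x = tf.random.normal([1, 5, 5, 2])
f = tf.Variable(tf.random.normal([3, 3, 2, 4]))

with tf.GradientTape() as tape:
    y = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1],
                     padding=[[0, 0], [1, 1], [2, 2], [0, 0]])
    loss = tf.reduce_sum(y)

# Dispatches to Conv2DBackpropFilter, whose CPU registration now
# accepts EXPLICIT padding.
grad_f = tape.gradient(loss, f)
```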
@@ -299,14 +299,9 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
                 errors::InvalidArgument(
                     "Current libxsmm and customized CPU implementations do "
                     "not yet support dilation rates larger than 1."));
-    OP_REQUIRES(
-        context, padding_ != Padding::EXPLICIT,
-        errors::Unimplemented("Current CPU implementation does not support "
-                              "EXPLICIT padding yet."));
-    std::vector<int64> explicit_paddings;
-    OP_REQUIRES_OK(context,
-                   context->GetAttr("explicit_paddings", &explicit_paddings));
-    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings,
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("explicit_paddings", &explicit_paddings_));
+    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                               /*num_dims=*/4, data_format_));
   }

@@ -325,10 +320,11 @@ class Conv2DCustomBackpropInputOp : public OpKernel {

     ConvBackpropDimensions dims;
     OP_REQUIRES_OK(context,
-                   ConvBackpropComputeDimensions(
+                   ConvBackpropComputeDimensionsV2(
                        "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
                        input_shape, filter.shape(), out_backprop.shape(),
-                       strides_, padding_, data_format_, &dims));
+                       /*dilations=*/{1, 1, 1, 1}, strides_, padding_,
+                       explicit_paddings_, data_format_, &dims));

     Tensor* in_backprop = nullptr;
     OP_REQUIRES_OK(context,

@@ -375,6 +371,12 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
     int64 pad_top, pad_bottom;
     int64 pad_left, pad_right;
 #endif
+    if (padding_ == Padding::EXPLICIT) {
+      pad_top = explicit_paddings_[2];
+      pad_bottom = explicit_paddings_[3];
+      pad_left = explicit_paddings_[4];
+      pad_right = explicit_paddings_[5];
+    }
     OP_REQUIRES_OK(
         context,
         GetWindowedOutputSizeVerbose(

@@ -536,6 +538,7 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
   std::vector<int32> dilations_;
   std::vector<int32> strides_;
   Padding padding_;
+  std::vector<int64> explicit_paddings_;
   TensorFormat data_format_;

   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);

@@ -617,12 +620,6 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
     use_cudnn_ &= CanUseCudnn();
     cudnn_use_autotune_ = CudnnUseAutotune();
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-    if (!std::is_same<Device, GPUDevice>::value) {
-      OP_REQUIRES(
-          context, padding_ != Padding::EXPLICIT,
-          errors::Unimplemented("Current CPU implementation does not support "
-                                "EXPLICIT padding yet."));
-    }
     OP_REQUIRES_OK(context,
                    context->GetAttr("explicit_paddings", &explicit_paddings_));
     OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
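Same story for the input gradient; a sketch calling the raw op the way the tests below do (shapes are hypothetical but consistent: a 5x5 input padded by (1,1)x(2,2) under a 3x3 filter gives a 5x7 output):

```python
import tensorflow as tf
from tensorflow.python.ops import nn_ops

f = tf.random.normal([3, 3, 2, 4])
out_backprop = tf.random.normal([1, 5, 7, 4])

# Direct call to the input-gradient op with explicit padding;
# this now runs on CPU as well as GPU.
dx = nn_ops.conv2d_backprop_input(
    input_sizes=[1, 5, 5, 2],
    filter=f,
    out_backprop=out_backprop,
    strides=[1, 1, 1, 1],
    padding=[[0, 0], [1, 1], [2, 2], [0, 0]])
```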
@@ -70,11 +70,12 @@ struct LaunchGeneric {
   void operator()(OpKernelContext* ctx, const Tensor& input,
                   const Tensor& filter, int row_stride, int col_stride,
                   int row_dilation, int col_dilation, const Padding& padding,
-                  Tensor* output, TensorFormat data_format) {
+                  const std::vector<int64>& explicit_paddings, Tensor* output,
+                  TensorFormat data_format) {
     CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
                                          "supports NHWC tensor format for now.";
     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
-        col_stride == 1) {
+        col_stride == 1 && (padding == SAME || padding == VALID)) {
       // For 1x1 kernel, the 2D convolution is reduced to matrix
       // multiplication.
       //

@@ -110,10 +111,20 @@ struct LaunchGeneric {
               input.shaped<T, 2>({input.dim_size(0), k}),
               filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair);
     } else {
-      functor::SpatialConvolution<Device, T>()(
-          ctx->eigen_device<Device>(), output->tensor<T, 4>(),
-          input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
-          row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
+      if (padding == EXPLICIT) {
+        functor::SpatialConvolution<Device, T>()(
+            ctx->eigen_device<Device>(), output->tensor<T, 4>(),
+            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
+            row_dilation, col_dilation, static_cast<int>(explicit_paddings[2]),
+            static_cast<int>(explicit_paddings[3]),
+            static_cast<int>(explicit_paddings[4]),
+            static_cast<int>(explicit_paddings[5]));
+      } else {
+        functor::SpatialConvolution<Device, T>()(
+            ctx->eigen_device<Device>(), output->tensor<T, 4>(),
+            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
+            row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
+      }
     }
   }
 };

@@ -133,18 +144,19 @@ struct LaunchConv2DOp<CPUDevice, T> {
                                       "NHWC tensor format for now."));
       return;
     }
-    // TODO(reedwm): Enable explicit padding on the CPU.
-    OP_REQUIRES(
-        ctx, padding != Padding::EXPLICIT,
-        errors::Unimplemented("Generic conv implementation does not support "
-                              "EXPLICIT padding yet."));
     const int64 in_depth = GetTensorDim(input, data_format, 'C');
     OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
                 errors::Unimplemented("Generic conv implementation does not "
                                       "support grouped convolutions for now."));
+    for (int64 explicit_padding : explicit_paddings) {
+      if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
+        ctx->SetStatus(errors::InvalidArgument("filter too large"));
+        return;
+      }
+    }
     LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
-                                  row_dilation, col_dilation, padding, output,
-                                  data_format);
+                                  row_dilation, col_dilation, padding,
+                                  explicit_paddings, output, data_format);
   }
 };
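The kernels index `explicit_paddings[2]` through `[5]` because the attr is the 4x2 Python padding list flattened in row-major order; for NHWC, entries 0-1 are the batch pads, 2-3 the height pads (top/bottom), 4-5 the width pads (left/right), and 6-7 the channel pads. A sketch of that convention:

```python
pad_top, pad_bottom, pad_left, pad_right = 1, 2, 3, 4  # example values

padding = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
explicit_paddings = [p for pair in padding for p in pair]

assert explicit_paddings[2] == pad_top
assert explicit_paddings[3] == pad_bottom
assert explicit_paddings[4] == pad_left
assert explicit_paddings[5] == pad_right
```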
@@ -1313,6 +1313,10 @@ struct gemm_pack_rhs<
  * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
  * pixels.
  *
+ * If padding_top, padding_bottom, padding_left, or padding_right is specified,
+ * then those paddings will be used to pad the input, and padding_type must be
+ * PADDING_VALID.
+ *
  * The result can be assigned to a tensor of rank equal to the rank of the
  * input. The dimensions of the result will be filters, height, width (and
  * others if applicable).

@@ -1360,7 +1364,9 @@ EIGEN_DEVICE_FUNC
                    const PaddingType padding_type = PADDING_SAME,
                    const Index row_in_stride = 1,
                    const Index col_in_stride = 1,
-                   const OutputKernel& output_kernel = OutputKernel()) {
+                   const OutputKernel& output_kernel = OutputKernel(),
+                   Index padding_top = 0, Index padding_bottom = 0,
+                   Index padding_left = 0, Index padding_right = 0) {
   typedef typename internal::traits<Input>::Index TensorIndex;
   TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                    internal::traits<Input>::NumDimensions,

@@ -1402,25 +1408,33 @@ EIGEN_DEVICE_FUNC
       isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
   const TensorIndex InputCols =
       isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+  const bool padding_explicit =
+      (padding_top || padding_bottom || padding_left || padding_right);

   TensorIndex out_height;
   TensorIndex out_width;
   switch (padding_type) {
-    case PADDING_VALID:
-      out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) /
+    case PADDING_VALID: {
+      const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
+      const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
+      out_height = numext::ceil((InputRowsEff - kernelRowsEff + 1.f) /
                                 static_cast<float>(row_stride));
-      out_width = numext::ceil((InputCols - kernelColsEff + 1.f) /
+      out_width = numext::ceil((InputColsEff - kernelColsEff + 1.f) /
                                static_cast<float>(col_stride));
       break;
-    case PADDING_SAME:
+    }
+    case PADDING_SAME: {
+      eigen_assert(!padding_explicit);
       out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
       out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
       break;
-    default:
+    }
+    default: {
       // Initialize unused variables to avoid a compiler warning
       out_height = 0;
       out_width = 0;
       eigen_assert(false && "unexpected padding");
+    }
   }

   // Molds the output of the patch extraction code into a 2d tensor:

@@ -1473,22 +1487,50 @@ EIGEN_DEVICE_FUNC
     kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
     kernel_dims[1] = kernelFilters;
   }
-  return choose(
-      Cond<internal::traits<Input>::Layout == ColMajor>(),
-      kernel.reshape(kernel_dims)
-          .contract(input
-                        .extract_image_patches(
-                            kernelRows, kernelCols, row_stride, col_stride,
-                            row_in_stride, col_in_stride, padding_type)
-                        .reshape(pre_contract_dims),
-                    contract_dims, output_kernel)
-          .reshape(post_contract_dims),
-      input
-          .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
-                                 row_in_stride, col_in_stride, padding_type)
-          .reshape(pre_contract_dims)
-          .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
-          .reshape(post_contract_dims));
+  if (padding_explicit) {
+    return choose(
+        Cond<internal::traits<Input>::Layout == ColMajor>(),
+        kernel.reshape(kernel_dims)
+            .contract(input
+                          .extract_image_patches(
+                              kernelRows, kernelCols, row_stride, col_stride,
+                              row_in_stride, col_in_stride,
+                              /*row_inflate_stride=*/1,
+                              /*col_inflate_stride=*/1, padding_top,
+                              padding_bottom, padding_left, padding_right,
+                              /*padding_value=*/0)
+                          .reshape(pre_contract_dims),
+                      contract_dims, output_kernel)
+            .reshape(post_contract_dims),
+        input
+            .extract_image_patches(kernelRows, kernelCols, row_stride,
+                                   col_stride, row_in_stride, col_in_stride,
+                                   /*row_inflate_stride=*/1,
+                                   /*col_inflate_stride=*/1, padding_top,
+                                   padding_bottom, padding_left, padding_right,
+                                   /*padding_value=*/0)
+            .reshape(pre_contract_dims)
+            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
+            .reshape(post_contract_dims));
+  } else {
+    return choose(
+        Cond<internal::traits<Input>::Layout == ColMajor>(),
+        kernel.reshape(kernel_dims)
+            .contract(input
+                          .extract_image_patches(
+                              kernelRows, kernelCols, row_stride, col_stride,
+                              row_in_stride, col_in_stride, padding_type)
+                          .reshape(pre_contract_dims),
+                      contract_dims, output_kernel)
+            .reshape(post_contract_dims),
+        input
+            .extract_image_patches(kernelRows, kernelCols, row_stride,
+                                   col_stride, row_in_stride, col_in_stride,
+                                   padding_type)
+            .reshape(pre_contract_dims)
+            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
+            .reshape(post_contract_dims));
+  }
 }

 } // end namespace Eigen
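In the `PADDING_VALID` branch above, the output size is computed from the padded ("effective") input extent and the dilated kernel extent (`kernelRowsEff = (kernelRows - 1) * row_in_stride + 1`). A sketch of the same arithmetic:

```python
import math

def valid_out_size(in_size, kernel, stride, in_stride, pad_before, pad_after):
    """Mirrors the PADDING_VALID branch: effective input vs. dilated kernel."""
    in_eff = in_size + pad_before + pad_after
    kernel_eff = (kernel - 1) * in_stride + 1
    return int(math.ceil((in_eff - kernel_eff + 1) / stride))

# 5 input rows, 3-row kernel, stride 1, no dilation, pad 1 each side -> 5 rows.
assert valid_out_size(5, 3, 1, 1, 1, 1) == 5
```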
@@ -403,7 +403,6 @@ class Conv2DTest(test.TestCase):
           padding,
           expected,
           dilations,
-          gpu_only=True,
           test_grappler_layout_optimizer=test_grappler_layout_optimizer,
           tol=tol,
           fp16_tol=fp16_tol)
@@ -1429,8 +1428,14 @@ class Conv2DTest(test.TestCase):
                                             strides,
                                             padding,
                                             data_format,
+                                            use_gpu,
                                             dilations=(1, 1),
                                             err=2e-5):
+    if use_gpu and not test.is_gpu_available(cuda_only=True):
+      return
+    if not use_gpu and dilations != (1, 1):
+      return  # Non-default dilations are currently not supported on the CPU.
+
    x1 = self._CreateNumpyTensor(filter_sizes)
    x2 = self._CreateNumpyTensor(output_sizes)
    dilations = list(dilations)
@@ -1455,133 +1460,128 @@ class Conv2DTest(test.TestCase):
         padding,
         expected,
         data_format,
-        use_gpu=True,
+        use_gpu=use_gpu,
         err=err,
         dilations=dilations)

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding0x0BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 1, 2, 1],
-            strides=[1, 1],
-            padding=[[0, 0], [0, 0]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 1, 2, 1],
+          strides=[1, 1],
+          padding=[[0, 0], [0, 0]],
+          data_format=data_format,
+          use_gpu=use_gpu)

-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 3, 4, 2],
-            filter_sizes=[2, 2, 2, 3],
-            output_sizes=[1, 1, 2, 3],
-            strides=[2, 2],
-            padding=[[0, 0], [0, 0]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 3, 4, 2],
+          filter_sizes=[2, 2, 2, 3],
+          output_sizes=[1, 1, 2, 3],
+          strides=[2, 2],
+          padding=[[0, 0], [0, 0]],
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding1x1BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
-
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 2],
-            output_sizes=[1, 3, 4, 2],
-            strides=[1, 1],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format, err=1e-4)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 2],
+          output_sizes=[1, 3, 4, 2],
+          strides=[1, 1],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=1e-4)

-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 2, 3, 2],
-            filter_sizes=[1, 1, 2, 1],
-            output_sizes=[1, 4, 3, 1],
-            strides=[1, 2],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 2, 3, 2],
+          filter_sizes=[1, 1, 2, 1],
+          output_sizes=[1, 4, 3, 1],
+          strides=[1, 2],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          use_gpu=use_gpu)

-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 4, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 4, 2, 1],
-            strides=[1, 2],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format,
-            dilations=[2, 2])
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 4, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 4, 2, 1],
+          strides=[1, 2],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          dilations=[2, 2], use_gpu=use_gpu)

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding2x2BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
-
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[2, 3, 1, 1],
-            filter_sizes=[2, 1, 1, 1],
-            output_sizes=[2, 2, 5, 1],
-            strides=[3, 1],
-            padding=[[2, 2], [2, 2]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[2, 3, 1, 1],
+          filter_sizes=[2, 1, 1, 1],
+          output_sizes=[2, 2, 5, 1],
+          strides=[3, 1],
+          padding=[[2, 2], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu)

-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 3, 6, 1],
-            filter_sizes=[3, 2, 1, 1],
-            output_sizes=[1, 3, 4, 1],
-            strides=[1, 2],
-            padding=[[2, 2], [2, 2]],
-            data_format=data_format,
-            dilations=[2, 3])
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 3, 6, 1],
+          filter_sizes=[3, 2, 1, 1],
+          output_sizes=[1, 3, 4, 1],
+          strides=[1, 2],
+          padding=[[2, 2], [2, 2]],
+          data_format=data_format,
+          dilations=[2, 3],
+          use_gpu=use_gpu)

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding_1_8_4_1_BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 10, 8, 1],
-            strides=[1, 1],
-            padding=[[1, 8], [4, 2]],
-            data_format=data_format, err=5e-5)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 10, 8, 1],
+          strides=[1, 1],
+          padding=[[1, 8], [4, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=5e-5)

-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 5, 3, 1],
-            filter_sizes=[3, 2, 1, 1],
-            output_sizes=[1, 4, 8, 1],
-            strides=[3, 1],
-            padding=[[1, 8], [4, 2]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 5, 3, 1],
+          filter_sizes=[3, 2, 1, 1],
+          output_sizes=[1, 4, 8, 1],
+          strides=[3, 1],
+          padding=[[1, 8], [4, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding_5_0_2_2_BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 3, 3, 1],
-            filter_sizes=[2, 1, 1, 1],
-            output_sizes=[1, 7, 7, 1],
-            strides=[1, 1],
-            padding=[[5, 0], [2, 2]],
-            data_format=data_format,
-            err=5e-5)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 3, 3, 1],
+          filter_sizes=[2, 1, 1, 1],
+          output_sizes=[1, 7, 7, 1],
+          strides=[1, 1],
+          padding=[[5, 0], [2, 2]],
+          data_format=data_format,
+          err=5e-5,
+          use_gpu=use_gpu)

-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 4, 2, 1],
-            filter_sizes=[3, 3, 1, 1],
-            output_sizes=[1, 5, 2, 1],
-            strides=[1, 2],
-            padding=[[5, 0], [2, 2]],
-            data_format=data_format,
-            dilations=[2, 1])
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 4, 2, 1],
+          filter_sizes=[3, 3, 1, 1],
+          output_sizes=[1, 5, 2, 1],
+          strides=[1, 2],
+          padding=[[5, 0], [2, 2]],
+          data_format=data_format,
+          dilations=[2, 1],
+          use_gpu=use_gpu)

   def _RunAndVerifyBackpropFilterExplicitPadding(self,
                                                  input_sizes,
@@ -1590,8 +1590,14 @@ class Conv2DTest(test.TestCase):
                                              strides,
                                              padding,
                                              data_format,
+                                             use_gpu,
                                              dilations=(1, 1),
                                              err=1e-5):
+    if use_gpu and not test.is_gpu_available(cuda_only=True):
+      return
+    if not use_gpu and dilations != (1, 1):
+      return  # Non-default dilations are currently not supported on the CPU.
+
    x0 = self._CreateNumpyTensor(input_sizes)
    x2 = self._CreateNumpyTensor(output_sizes)
    dilations = list(dilations)
@@ -1613,135 +1619,127 @@ class Conv2DTest(test.TestCase):
         padding,
         expected,
         data_format,
-        use_gpu=True,
+        use_gpu=use_gpu,
         dilations=dilations,
         err=err)

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding0x0BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 1, 2, 1],
-            strides=[1, 1],
-            padding=[[0, 0], [0, 0]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 1, 2, 1],
+          strides=[1, 1],
+          padding=[[0, 0], [0, 0]],
+          data_format=data_format, use_gpu=use_gpu)

-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 3, 4, 2],
-            filter_sizes=[2, 2, 2, 3],
-            output_sizes=[1, 1, 2, 3],
-            strides=[2, 2],
-            padding=[[0, 0], [0, 0]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 3, 4, 2],
+          filter_sizes=[2, 2, 2, 3],
+          output_sizes=[1, 1, 2, 3],
+          strides=[2, 2],
+          padding=[[0, 0], [0, 0]],
+          data_format=data_format, use_gpu=use_gpu)

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding1x1BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
-
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 2],
-            output_sizes=[1, 3, 4, 2],
-            strides=[1, 1],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format,
-            err=5e-5)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 2],
+          output_sizes=[1, 3, 4, 2],
+          strides=[1, 1],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=5e-5)

-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 2, 3, 2],
-            filter_sizes=[1, 1, 2, 1],
-            output_sizes=[1, 4, 3, 1],
-            strides=[1, 2],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 2, 3, 2],
+          filter_sizes=[1, 1, 2, 1],
+          output_sizes=[1, 4, 3, 1],
+          strides=[1, 2],
+          padding=[[1, 1], [1, 1]],
+          use_gpu=use_gpu,
+          data_format=data_format)

-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 4, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 4, 2, 1],
-            strides=[1, 2],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format,
-            dilations=[2, 2])
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 4, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 4, 2, 1],
+          strides=[1, 2],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          dilations=[2, 2])

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding2x2BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
-
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[2, 3, 1, 1],
-            filter_sizes=[2, 1, 1, 1],
-            output_sizes=[2, 2, 5, 1],
-            strides=[3, 1],
-            padding=[[2, 2], [2, 2]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[2, 3, 1, 1],
+          filter_sizes=[2, 1, 1, 1],
+          output_sizes=[2, 2, 5, 1],
+          strides=[3, 1],
+          padding=[[2, 2], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu)

-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 3, 6, 1],
-            filter_sizes=[3, 2, 1, 1],
-            output_sizes=[1, 3, 4, 1],
-            strides=[1, 2],
-            padding=[[2, 2], [2, 2]],
-            data_format=data_format,
-            dilations=[2, 3])
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 3, 6, 1],
+          filter_sizes=[3, 2, 1, 1],
+          output_sizes=[1, 3, 4, 1],
+          strides=[1, 2],
+          padding=[[2, 2], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          dilations=[2, 3])

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding_1_8_4_1_BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 10, 8, 1],
-            strides=[1, 1],
-            padding=[[1, 8], [4, 2]],
-            data_format=data_format,
-            err=1e-4)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 10, 8, 1],
+          strides=[1, 1],
+          padding=[[1, 8], [4, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=1e-4)

-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 5, 3, 1],
-            filter_sizes=[3, 2, 1, 1],
-            output_sizes=[1, 4, 8, 1],
-            strides=[3, 1],
-            padding=[[1, 8], [4, 2]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 5, 3, 1],
+          filter_sizes=[3, 2, 1, 1],
+          output_sizes=[1, 4, 8, 1],
+          strides=[3, 1],
+          padding=[[1, 8], [4, 2]],
+          use_gpu=use_gpu,
+          data_format=data_format)

   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding_5_0_2_2_BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 3, 3, 1],
-            filter_sizes=[2, 1, 1, 1],
-            output_sizes=[1, 7, 7, 1],
-            strides=[1, 1],
-            padding=[[5, 0], [2, 2]],
-            data_format=data_format,
-            err=1e-4)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 3, 3, 1],
+          filter_sizes=[2, 1, 1, 1],
+          output_sizes=[1, 7, 7, 1],
+          strides=[1, 1],
+          padding=[[5, 0], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=1e-4)

-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 4, 2, 1],
-            filter_sizes=[3, 3, 1, 1],
-            output_sizes=[1, 5, 2, 1],
-            strides=[1, 2],
-            padding=[[5, 0], [2, 2]],
-            data_format=data_format,
-            dilations=[2, 1])
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 4, 2, 1],
+          filter_sizes=[3, 3, 1, 1],
+          output_sizes=[1, 5, 2, 1],
+          strides=[1, 2],
+          padding=[[5, 0], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          dilations=[2, 1])

   # Gradient checkers
   def ConstructAndTestGradient(self,
@@ -2107,257 +2105,221 @@ class Conv2DTest(test.TestCase):

   @test_util.deprecated_graph_mode_only
   def testInputGradient1x1PaddingStrideOne(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=5,
-            input_cols=4,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=1,
-            stride_cols=1,
-            padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu,
-            max_err=0.0025)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=5,
+          input_cols=4,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=1,
+          stride_cols=1,
+          padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu,
+          max_err=0.0025)

   @test_util.deprecated_graph_mode_only
   def testFilterGradient1x1PaddingStrideOne(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=5,
-            input_cols=4,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=1,
-            stride_cols=1,
-            padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=5,
+          input_cols=4,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=1,
+          stride_cols=1,
+          padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testInputGradient1x1PaddingStrideTwo(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=4,
-            input_cols=5,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=2,
-            stride_cols=2,
-            padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=4,
+          input_cols=5,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=2,
+          stride_cols=2,
+          padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testFilterGradient1x1PaddingStrideTwo(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=4,
-            input_cols=5,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=2,
-            stride_cols=2,
-            padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=4,
+          input_cols=5,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=2,
+          stride_cols=2,
+          padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testInputGradient2x2PaddingStrideOne(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=5,
-            input_cols=4,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=1,
-            stride_cols=1,
-            padding=[[0, 0], [2, 2], [2, 2], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=5,
+          input_cols=4,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=1,
+          stride_cols=1,
+          padding=[[0, 0], [2, 2], [2, 2], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testFilterGradient2x2PaddingStrideOne(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=5,
-            input_cols=4,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=1,
-            stride_cols=1,
-            padding=[[0, 0], [2, 2], [2, 2], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu,
-            max_err=0.003)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=5,
+          input_cols=4,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=1,
+          stride_cols=1,
+          padding=[[0, 0], [2, 2], [2, 2], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu,
+          max_err=0.003)

   @test_util.deprecated_graph_mode_only
   def testInputGradient1_2_3_4PaddingStride3x2(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=8,
-            input_cols=5,
-            filter_rows=4,
-            filter_cols=2,
-            in_depth=3,
-            out_depth=2,
-            stride_rows=3,
-            stride_cols=2,
-            padding=[[0, 0], [1, 2], [3, 4], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=8,
+          input_cols=5,
+          filter_rows=4,
+          filter_cols=2,
+          in_depth=3,
+          out_depth=2,
+          stride_rows=3,
+          stride_cols=2,
+          padding=[[0, 0], [1, 2], [3, 4], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testFilterGradient1_2_3_4PaddingStride3x2(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=8,
-            input_cols=5,
-            filter_rows=4,
-            filter_cols=2,
-            in_depth=3,
-            out_depth=2,
-            stride_rows=3,
-            stride_cols=2,
-            padding=[[0, 0], [1, 2], [3, 4], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=8,
+          input_cols=5,
+          filter_rows=4,
+          filter_cols=2,
+          in_depth=3,
+          out_depth=2,
+          stride_rows=3,
+          stride_cols=2,
+          padding=[[0, 0], [1, 2], [3, 4], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testInputGradient4_3_2_1PaddingStride2x1(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=3,
-            input_rows=5,
-            input_cols=7,
-            filter_rows=3,
-            filter_cols=2,
-            in_depth=1,
-            out_depth=2,
-            stride_rows=2,
-            stride_cols=1,
-            padding=[[0, 0], [4, 3], [2, 1], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=3,
+          input_rows=5,
+          input_cols=7,
+          filter_rows=3,
+          filter_cols=2,
+          in_depth=1,
+          out_depth=2,
+          stride_rows=2,
+          stride_cols=1,
+          padding=[[0, 0], [4, 3], [2, 1], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testFilterGradient4_3_2_1PaddingStride2x1(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=3,
-            input_rows=5,
-            input_cols=7,
-            filter_rows=3,
-            filter_cols=2,
-            in_depth=1,
-            out_depth=2,
-            stride_rows=2,
-            stride_cols=1,
-            padding=[[0, 0], [4, 3], [2, 1], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=3,
+          input_rows=5,
+          input_cols=7,
+          filter_rows=3,
+          filter_cols=2,
+          in_depth=1,
+          out_depth=2,
+          stride_rows=2,
+          stride_cols=1,
+          padding=[[0, 0], [4, 3], [2, 1], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testInputGradient0_0_0_5PaddingStride1x2(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=6,
-            input_cols=7,
-            filter_rows=3,
-            filter_cols=4,
-            in_depth=3,
-            out_depth=2,
-            stride_rows=1,
-            stride_cols=2,
-            padding=[[0, 0], [0, 0], [0, 5], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=6,
+          input_cols=7,
+          filter_rows=3,
+          filter_cols=4,
+          in_depth=3,
+          out_depth=2,
+          stride_rows=1,
+          stride_cols=2,
+          padding=[[0, 0], [0, 0], [0, 5], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testFilterGradient0_0_0_5PaddingStride1x2(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=6,
-            input_cols=7,
-            filter_rows=3,
-            filter_cols=4,
-            in_depth=3,
-            out_depth=2,
-            stride_rows=1,
-            stride_cols=2,
-            padding=[[0, 0], [0, 0], [0, 5], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=6,
+          input_cols=7,
+          filter_rows=3,
+          filter_cols=4,
+          in_depth=3,
+          out_depth=2,
+          stride_rows=1,
+          stride_cols=2,
+          padding=[[0, 0], [0, 0], [0, 5], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)

   @test_util.deprecated_graph_mode_only
   def testShapeFunctionEdgeCases(self):
@@ -2505,31 +2467,29 @@ class Conv2DTest(test.TestCase):
                   strides=[1, 1, 1, 1],
                   padding=[[0, 0], [2, 2], [2, 2], [0, 0]]))

-    if test.is_gpu_available(cuda_only=True):
-      with self.test_session(use_gpu=True):
-        # Negative padding during backprop.
-        with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
-                                     "nonnegative"):
-          sess.run(
-              nn_ops.conv2d_backprop_input([32, 20, 20, 3],
-                                           array_ops.placeholder(
-                                               dtypes.float32,
-                                               shape=[18, 18, 3, 2]),
-                                           array_ops.placeholder(
-                                               dtypes.float32,
-                                               shape=[32, 3, 2, 2]),
-                                           strides=[1, 1, 1, 1],
-                                           padding=[[0, 0], [-1, 0], [0, 0],
-                                                    [0, 0]]))
-        with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
-                                     "nonnegative"):
-          sess.run(
-              nn_ops.conv2d_backprop_filter(
-                  array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
-                  [18, 18, 3, 2],
-                  array_ops.placeholder(dtypes.float32, shape=[32, 3, 2, 2]),
-                  strides=[1, 1, 1, 1],
-                  padding=[[0, 0], [-1, 0], [0, 0], [0, 0]]))
+    # Negative padding during backprop.
+    with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+                                 "nonnegative"):
+      sess.run(
+          nn_ops.conv2d_backprop_input([32, 20, 20, 3],
+                                       array_ops.placeholder(
+                                           dtypes.float32,
+                                           shape=[18, 18, 3, 2]),
+                                       array_ops.placeholder(
+                                           dtypes.float32,
+                                           shape=[32, 3, 2, 2]),
+                                       strides=[1, 1, 1, 1],
+                                       padding=[[0, 0], [-1, 0], [0, 0],
+                                                [0, 0]]))
+    with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+                                 "nonnegative"):
+      sess.run(
+          nn_ops.conv2d_backprop_filter(
+              array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
+              [18, 18, 3, 2],
+              array_ops.placeholder(dtypes.float32, shape=[32, 3, 2, 2]),
+              strides=[1, 1, 1, 1],
+              padding=[[0, 0], [-1, 0], [0, 0], [0, 0]]))


 class DepthwiseConv2DTest(test.TestCase):
@@ -1904,7 +1904,7 @@ def conv2d(  # pylint: disable=redefined-builtin,dangerous-default-value
       value is given it is replicated in the `H` and `W` dimension. By default
       the `N` and `C` dimensions are set to 1. The dimension order is determined
       by the value of `data_format`, see below for details.
-    padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
+    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
       padding algorithm to use, or a list indicating the explicit paddings at
       the start and end of each dimension. When explicit padding is used and
       data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
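The docstring now covers three forms for `padding`; in short (hypothetical tensors):

```python
import tensorflow as tf

x = tf.random.normal([1, 5, 5, 3])
f = tf.random.normal([3, 3, 3, 8])

tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")   # same spatial size
tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding="VALID")  # no padding
tf.nn.conv2d(x, f, strides=[1, 1, 1, 1],                   # explicit, NHWC
             padding=[[0, 0], [1, 1], [2, 2], [0, 0]])
```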