Rollforward of 93c22a269b. Windows and RCOM builds have been fixed.

PiperOrigin-RevId: 329379040
Change-Id: I51aa3c9d6374639be5d632c42f56ba68d49f5e98
This commit is contained in:
Pankaj Kanwar 2020-08-31 14:27:39 -07:00 committed by TensorFlower Gardener
parent 4317a43467
commit 458d0906eb
28 changed files with 1157 additions and 211 deletions

View File

@ -183,11 +183,13 @@
checkpoint saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function
will automatically find the checkpoint in the SavedModel.
* `tf.nn`:
* `tf.nn.max_pool2d` now supports explicit padding.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
https://developers.google.com/style/word-list#blacklist for more context.
* <ADD RELEASE NOTES HERE>
<ADD RELEASE NOTES HERE>
## Thanks to our Contributors

View File

@ -149,6 +149,7 @@ def LegalizeMaxPool2D : Pat<
IsIntList1XY1:$ksize,
IsIntList1XY1:$strides,
$padding,
$explicit_paddings,
IsDataFormatNHWC:$format),
(TFL_MaxPool2DOp $value,
/*padding=*/$padding,

View File

@ -6311,7 +6311,8 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$ksize,
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$strides,
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
DefaultValuedAttr<TF_AnyStrAttrOf<["NHWC", "NCHW", "NCHW_VECT_C"]>, "NHWC">:$data_format
);
@ -6380,7 +6381,8 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> {
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$ksize,
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$strides,
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format
);

View File

@ -1269,6 +1269,15 @@ func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x
return %0 : tensor<2x8x4x7x7xf32>
}
// CHECK-LABEL: maxpool_explicit_padding
func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
// CHECK: tf.MaxPool
// TODO(b/165938852): need to support explicit padding in max_pool.
%0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
return %0 : tensor<2x3x5x7xi32>
}
//===----------------------------------------------------------------------===//
// MaxPoolGrad op legalizations.
//===----------------------------------------------------------------------===//

View File

@ -2474,6 +2474,12 @@ class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
Type element_type =
op.input().getType().template cast<TensorType>().getElementType();
if (!element_type.isSignlessIntOrFloat()) return failure();
tensorflow::Padding padding;
if (!GetPaddingFromString(op.padding().str(), &padding).ok())
return failure();
if (padding == tensorflow::Padding::EXPLICIT) {
return failure();
}
Location loc = op.getLoc();
ConstOp init = GetScalarLimitConstOfType(element_type, loc,
hlo::kInfinityLowest, &rewriter);

View File

@ -1477,7 +1477,8 @@ Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
return Status::OK();
}
Status MaxPoolShape(shape_inference::InferenceContext* c) {
Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
bool supports_explicit_padding) {
string data_format_str;
TensorFormat data_format;
Status s = c->GetAttr("data_format", &data_format_str);
@ -1530,14 +1531,39 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
std::vector<int64> explicit_paddings;
if (supports_explicit_padding) {
Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
// Use the default value, which is an empty list, if the attribute is not
// found. Otherwise return the error to the caller.
if (!status.ok() && !errors::IsNotFound(status)) {
return status;
}
TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
/*num_dims=*/4, data_format));
} else {
DCHECK(padding != Padding::EXPLICIT);
}
ShapeHandle output_shape;
DimensionHandle output_rows, output_cols, output_depth;
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
int64 pad_rows_before = -1, pad_rows_after = -1;
int64 pad_cols_before = -1, pad_cols_after = -1;
if (padding == Padding::EXPLICIT) {
GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
&pad_rows_before, &pad_rows_after);
GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
&pad_cols_before, &pad_cols_after);
}
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
c, in_rows_dim, kernel_rows, /*dilation_rate=*/1, stride_rows, padding,
pad_rows_before, pad_rows_after, &output_rows));
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
c, in_cols_dim, kernel_cols, /*dilation_rate=*/1, stride_cols, padding,
pad_cols_before, pad_cols_after, &output_cols));
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
c, in_depth_dim, kernel_depth, /*dilation_rate=*/1, stride_depth, padding,
/*pad_before*/ 0, /*pad_after*/ 0, &output_depth));
TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
{output_rows, output_cols},
@ -1547,6 +1573,14 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
Status MaxPoolShape(shape_inference::InferenceContext* c) {
return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false);
}
Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true);
}
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
string data_format_str;
TensorFormat data_format;

View File

@ -168,7 +168,11 @@ Status MatrixDiagV2Shape(shape_inference::InferenceContext* c);
// Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations.
Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c);
// Shape function for MaxPool-like operations.
// Shape function for MaxPool-like operations that support explicit padding.
Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c);
// Shape function for MaxPool-like operations that do not support explicit
// padding.
Status MaxPoolShape(shape_inference::InferenceContext* c);
// Shape function for MaxPoolV2-like operations.

View File

@ -81,8 +81,13 @@ class AvgPoolingOp : public UnaryOp<T> {
void Compute(OpKernelContext* context) override {
const Tensor& tensor_in = context->input(0);
PoolParameters params{context, ksize_, stride_,
padding_, data_format_, tensor_in.shape()};
PoolParameters params{context,
ksize_,
stride_,
padding_,
/*explicit_paddings=*/{},
data_format_,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -146,8 +151,13 @@ class AvgPoolingOp<GPUDevice, T> : public UnaryOp<T> {
void Compute(OpKernelContext* context) override {
const Tensor& tensor_in = context->input(0);
PoolParameters params{context, ksize_, stride_,
padding_, data_format_, tensor_in.shape()};
PoolParameters params{context,
ksize_,
stride_,
padding_,
/*explicit_paddings=*/{},
data_format_,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -169,14 +179,14 @@ class AvgPoolingOp<GPUDevice, T> : public UnaryOp<T> {
#if CUDNN_VERSION >= 7300
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, ksize_,
stride_, padding_, data_format_, tensor_in,
output_shape,
stride_, padding_, /*explicit_paddings=*/{},
data_format_, tensor_in, output_shape,
/*propagate_nans=*/false);
#else
if (data_format_ == FORMAT_NCHW) {
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, ksize_,
stride_, padding_, data_format_, tensor_in,
output_shape,
stride_, padding_, /*explicit_paddings=*/{},
data_format_, tensor_in, output_shape,
/*propagate_nans=*/false);
} else {
Tensor* output = nullptr;
@ -446,10 +456,10 @@ class AvgPoolingGradOp<GPUDevice, T> : public OpKernel {
return;
}
DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
ksize_, stride_, padding_, data_format_,
nullptr, nullptr, out_backprop, output_shape,
/*propagate_nans=*/false);
DnnPoolingGradOp<T>::Compute(
context, se::dnn::PoolingMode::kAverage, ksize_, stride_, padding_,
/*explicit_paddings=*/{}, data_format_, nullptr, nullptr, out_backprop,
output_shape, /*propagate_nans=*/false);
}
private:
@ -533,7 +543,8 @@ class AvgPoolingGradOpCustomGPUKernel : public OpKernel {
#if CUDNN_VERSION >= 7300
DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
ksize_, stride_, padding_, data_format_,
ksize_, stride_, padding_,
/*explicit_paddings=*/{}, data_format_,
nullptr, nullptr, out_backprop, output_shape,
/*propagate_nans=*/false);
#else
@ -589,7 +600,8 @@ class AvgPoolingGradOpCustomGPUKernel : public OpKernel {
context->eigen_gpu_device()); // d
} else {
DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
ksize_, stride_, padding_, data_format_,
ksize_, stride_, padding_,
/*explicit_paddings=*/{}, data_format_,
nullptr, nullptr, out_backprop, output_shape,
/*propagate_nans=*/false);
}

View File

@ -316,7 +316,7 @@ struct PadInput {
const std::array<int, NDIMS - 2>& padding_left,
const std::array<int, NDIMS - 2>& padding_right,
typename TTypes<T, NDIMS, IndexType>::Tensor out,
TensorFormat format) {
TensorFormat format, T padding_value = T{}) {
Eigen::array<Eigen::IndexPair<IndexType>, NDIMS> padding;
padding[GetTensorDimIndex<NDIMS - 2>(format, 'N')] = {0, 0};
for (int i = 0; i < NDIMS - 2; ++i) {
@ -324,7 +324,7 @@ struct PadInput {
padding_left[i], padding_right[i]};
}
padding[GetTensorDimIndex<NDIMS - 2>(format, 'C')] = {0, 0};
out.device(d) = in.pad(padding);
out.device(d) = in.pad(padding, padding_value);
}
};

View File

@ -417,14 +417,12 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
}
// A Gpu custom kernel that convert input to output, given proper padding on
// the left and the top. The padded value is zero.
// the left and the top.
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNHWC(int nthreads,
const T* __restrict__ input,
Dimension<NDIMS> input_dims,
T* __restrict__ output,
Dimension<NDIMS> output_dims,
Dimension<NDIMS - 2> padding_left) {
__global__ void PadInputCustomKernelNHWC(
int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
T* __restrict__ output, Dimension<NDIMS> output_dims,
Dimension<NDIMS - 2> padding_left, T padding_value) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int output_index = index;
Index<NDIMS> output_tensor_index =
@ -444,18 +442,16 @@ __global__ void PadInputCustomKernelNHWC(int nthreads,
const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] = input[input_index];
} else {
output[output_index] = T(0);
output[output_index] = padding_value;
}
}
}
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNCHW(int nthreads,
const T* __restrict__ input,
Dimension<NDIMS> input_dims,
T* __restrict__ output,
Dimension<NDIMS> output_dims,
Dimension<NDIMS - 2> padding_left) {
__global__ void PadInputCustomKernelNCHW(
int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
T* __restrict__ output, Dimension<NDIMS> output_dims,
Dimension<NDIMS - 2> padding_left, T padding_value) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int output_index = index;
Index<NDIMS> output_tensor_index =
@ -475,7 +471,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads,
const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] = input[input_index];
} else {
output[output_index] = T(0);
output[output_index] = padding_value;
}
}
}
@ -572,7 +568,7 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
const std::array<int, NDIMS - 2>& padding_left,
const std::array<int, NDIMS - 2>& padding_right,
typename TTypes<T, NDIMS, int>::Tensor out,
TensorFormat format) {
TensorFormat format, T padding_value = T{}) {
GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
Dimension<NDIMS> input_dims;
for (int i = 0; i < NDIMS; ++i) {
@ -589,12 +585,14 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
TF_CHECK_OK(GpuLaunchKernel(
PadInputCustomKernelNHWC<T, NDIMS>, config.block_count,
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
in.data(), input_dims, out.data(), output_dims, padding_left_dim));
in.data(), input_dims, out.data(), output_dims, padding_left_dim,
padding_value));
} else if (format == FORMAT_NCHW) {
TF_CHECK_OK(GpuLaunchKernel(
PadInputCustomKernelNCHW<T, NDIMS>, config.block_count,
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
in.data(), input_dims, out.data(), output_dims, padding_left_dim));
in.data(), input_dims, out.data(), output_dims, padding_left_dim,
padding_value));
} else {
LOG(FATAL) << "Invalid data format: " << format;
}

View File

@ -1135,19 +1135,20 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
void PadInput<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const std::array<int, 2>& padding_left, \
const std::array<int, 2>& padding_right, \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
void PadInput<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const std::array<int, 2>& padding_left, \
const std::array<int, 2>& padding_right, \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
T padding_value); \
extern template struct PadInput<GPUDevice, T, int, 4>;
DECLARE_GPU_SPEC(float);

View File

@ -1333,19 +1333,20 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
void PadInput<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const std::array<int, 2>& padding_left, \
const std::array<int, 2>& padding_right, \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
void PadInput<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const std::array<int, 2>& padding_left, \
const std::array<int, 2>& padding_right, \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
T padding_value); \
extern template struct PadInput<GPUDevice, T, int, 4>;
DECLARE_GPU_SPEC(float);

View File

@ -1082,7 +1082,8 @@ namespace functor {
const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
const std::array<int, 3>& padding_left, \
const std::array<int, 3>& padding_right, \
typename TTypes<T, 5, int>::Tensor out, TensorFormat format);
typename TTypes<T, 5, int>::Tensor out, TensorFormat format, \
T padding_value);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);

View File

@ -1183,7 +1183,8 @@ namespace functor {
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const std::array<int, 2>& padding_left, \
const std::array<int, 2>& padding_right, \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
T padding_value); \
extern template struct PadInput<GPUDevice, T, int, 4>
DECLARE_GPU_SPEC(float);

View File

@ -539,7 +539,8 @@ namespace functor {
const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
const std::array<int, 3>& padding_left, \
const std::array<int, 3>& padding_right, \
typename TTypes<T, 5, int>::Tensor out, TensorFormat format); \
typename TTypes<T, 5, int>::Tensor out, TensorFormat format, \
T padding_value); \
template <> \
void NHWCToNCHW<GPUDevice, T, 5>::operator()( \
const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in, \

View File

@ -804,19 +804,20 @@ class FusedConv2DOp : public OpKernel {
#if GOOGLE_CUDA
#define DECLARE_FUNCTOR_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
void PadInput<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const std::array<int, 2>& padding_left, \
const std::array<int, 2>& padding_right, \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
#define DECLARE_FUNCTOR_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
void PadInput<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const std::array<int, 2>& padding_left, \
const std::array<int, 2>& padding_right, \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
T padding_value); \
extern template struct PadInput<GPUDevice, T, int, 4>
// Registration of the GPU implementations.

View File

@ -105,8 +105,8 @@ static void SpatialMaxPoolWithArgMaxHelper(
const int32 depth = params.depth;
const int32 in_rows = params.tensor_in_rows;
const int32 in_cols = params.tensor_in_cols;
const int32 pad_rows = params.pad_rows;
const int32 pad_cols = params.pad_cols;
const int32 pad_top = params.pad_top;
const int32 pad_left = params.pad_left;
const int32 window_rows = params.window_rows;
const int32 window_cols = params.window_cols;
const int32 row_stride = params.row_stride;
@ -131,8 +131,8 @@ static void SpatialMaxPoolWithArgMaxHelper(
for (int w = 0; w < in_cols; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
const int hpad = h + pad_rows;
const int wpad = w + pad_cols;
const int hpad = h + pad_top;
const int wpad = w + pad_left;
const int h_start =
(hpad < window_rows) ? 0 : (hpad - window_rows) / row_stride + 1;
const int h_end = std::min(hpad / row_stride + 1, out_height);
@ -243,6 +243,13 @@ class MaxPoolingGradOp : public OpKernel {
}
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
if (padding_ == Padding::EXPLICIT) {
OP_REQUIRES_OK(
context, context->GetAttr("explicit_paddings", &explicit_paddings_));
OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
/*num_dims=*/4, data_format_));
}
}
void Compute(OpKernelContext* context) override {
@ -297,8 +304,13 @@ class MaxPoolingGradOp : public OpKernel {
errors::Unimplemented(
"MaxPoolingGrad is not yet supported on the depth dimension."));
PoolParameters params{context, ksize, stride,
padding_, FORMAT_NHWC, tensor_in.shape()};
PoolParameters params{context,
ksize,
stride,
padding_,
explicit_paddings_,
FORMAT_NHWC,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -316,6 +328,7 @@ class MaxPoolingGradOp : public OpKernel {
std::vector<int32> ksize_;
std::vector<int32> stride_;
Padding padding_;
std::vector<int64> explicit_paddings_;
TensorFormat data_format_;
};
@ -347,7 +360,12 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
"Pooling is not yet supported on the batch dimension."));
}
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
if (padding_ == Padding::EXPLICIT) {
OP_REQUIRES_OK(
context, context->GetAttr("explicit_paddings", &explicit_paddings_));
OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
/*num_dims=*/4, data_format_));
}
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
}
@ -392,16 +410,26 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
int64 pad_top, pad_bottom, pad_left, pad_right;
if (padding_ == Padding::EXPLICIT) {
GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H',
/*pad_top=*/&pad_top,
/*pad_bottom=*/&pad_bottom);
GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W',
/*pad_left=*/&pad_left,
/*pad_right=*/&pad_right);
}
DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize,
stride, padding_, data_format_, &tensor_in,
&tensor_out, out_backprop, output_shape,
propagate_nans_);
stride, padding_, explicit_paddings_,
data_format_, &tensor_in, &tensor_out,
out_backprop, output_shape, propagate_nans_);
}
private:
std::vector<int32> ksize_;
std::vector<int32> stride_;
Padding padding_;
std::vector<int64> explicit_paddings_;
TensorFormat data_format_;
bool propagate_nans_;
};
@ -492,8 +520,13 @@ class MaxPoolingGradGradOp : public OpKernel {
errors::Unimplemented(
"MaxPoolingGrad is not yet supported on the depth dimension."));
PoolParameters params{context, ksize, stride,
padding_, FORMAT_NHWC, tensor_in.shape()};
PoolParameters params{context,
ksize,
stride,
padding_,
/*explicit_paddings=*/{},
FORMAT_NHWC,
tensor_in.shape()};
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{2}, 0, tensor_out.shape(), &output));
@ -551,8 +584,8 @@ class MaxPoolingGradGradOp : public OpKernel {
const int32 depth = params.depth;
const int32 in_rows = params.tensor_in_rows;
const int32 in_cols = params.tensor_in_cols;
const int32 pad_rows = params.pad_rows;
const int32 pad_cols = params.pad_cols;
const int32 pad_top = params.pad_top;
const int32 pad_left = params.pad_left;
const int32 window_rows = params.window_rows;
const int32 window_cols = params.window_cols;
const int32 row_stride = params.row_stride;
@ -574,9 +607,9 @@ class MaxPoolingGradGradOp : public OpKernel {
for (int pw = 0; pw < out_width; ++pw) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
int h_start = ph * row_stride - pad_rows;
int h_start = ph * row_stride - pad_top;
const int h_end = std::min(h_start + window_rows, in_rows);
int w_start = pw * col_stride - pad_cols;
int w_start = pw * col_stride - pad_left;
const int w_end = std::min(w_start + window_cols, in_cols);
h_start = std::max(h_start, 0);
w_start = std::max(w_start, 0);
@ -691,15 +724,20 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
PoolParameters params{context, ksize, stride,
padding_, data_format_, tensor_in.shape()};
PoolParameters params{context,
ksize,
stride,
padding_,
/*explicit_paddings=*/{},
data_format_,
tensor_in.shape()};
functor::MaxPoolGradBackwardNoMask<T>()(
data_format_, tensor_in.flat<T>().data(), tensor_out.flat<T>().data(),
params.tensor_in_batch, params.out_height, params.out_width,
params.depth, params.tensor_in_rows, params.tensor_in_cols,
params.window_rows, params.window_cols, params.row_stride,
params.col_stride, params.pad_rows, params.pad_cols,
params.col_stride, params.pad_top, params.pad_left,
out_grad_backprop.flat<T>().data(), output->flat<T>().data(),
context->eigen_device<Eigen::GpuDevice>());
}
@ -743,13 +781,22 @@ class MaxPoolingNoMaskOp : public OpKernel {
OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
OP_REQUIRES(
context, padding_ != EXPLICIT,
errors::Unimplemented(
"Explicit padding is not supported for MaxPoolingNoMaskOp."));
}
void Compute(OpKernelContext* context) override {
const Tensor& tensor_in = context->input(0);
PoolParameters params{context, ksize_, stride_,
padding_, data_format_, tensor_in.shape()};
PoolParameters params{context,
ksize_,
stride_,
padding_,
/*explicit_paddings=*/{},
data_format_,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -826,8 +873,13 @@ class MaxPoolingNoMaskV2Op : public OpKernel {
OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
PoolParameters params{context, ksize, stride,
padding_, data_format_, tensor_in.shape()};
PoolParameters params{context,
ksize,
stride,
padding_,
/*explicit_paddings=*/{},
data_format_,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -889,8 +941,13 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& tensor_in = context->input(0);
PoolParameters params{context, ksize_, stride_,
padding_, FORMAT_NHWC, tensor_in.shape()};
PoolParameters params{context,
ksize_,
stride_,
padding_,
/*explicit_paddings=*/{},
FORMAT_NHWC,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -1003,8 +1060,13 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
const Tensor& grad_in = context->input(1);
const Tensor& argmax = context->input(2);
PoolParameters params{context, ksize_, stride_,
padding_, FORMAT_NHWC, tensor_in.shape()};
PoolParameters params{context,
ksize_,
stride_,
padding_,
/*explicit_paddings=*/{},
FORMAT_NHWC,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -1056,8 +1118,13 @@ class MaxPoolingGradGradWithArgmaxOp : public OpKernel {
const Tensor& grad_in = context->input(1);
const Tensor& argmax = context->input(2);
PoolParameters params{context, ksize_, stride_,
padding_, FORMAT_NHWC, tensor_in.shape()};
PoolParameters params{context,
ksize_,
stride_,
padding_,
/*explicit_paddings=*/{},
FORMAT_NHWC,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -1100,6 +1167,8 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
errors::InvalidArgument("Sliding window stride field must "
"specify 4 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
OP_REQUIRES_OK(context,
context->GetAttr("explicit_paddings", &explicit_paddings_));
const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
@ -1113,8 +1182,9 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& tensor_in = context->input(0);
PoolParameters params{context, ksize_, stride_,
padding_, data_format_, tensor_in.shape()};
PoolParameters params{
context, ksize_, stride_, padding_, explicit_paddings_,
data_format_, tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -1131,15 +1201,20 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
#if CUDNN_VERSION >= 7300
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
stride_, padding_, data_format_, tensor_in,
out_shape, propagate_nans_);
stride_, padding_, explicit_paddings_,
data_format_, tensor_in, out_shape,
propagate_nans_);
#else
// These is_int8x4 checks avoid linker errors for missing qint8 kernels.
if (!is_int8x4 && data_format_ == FORMAT_NCHW) {
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
stride_, padding_, data_format_, tensor_in,
out_shape, propagate_nans_);
stride_, padding_, explicit_paddings_,
data_format_, tensor_in, out_shape,
propagate_nans_);
} else {
OP_REQUIRES(context, padding_ != EXPLICIT,
errors::Unimplemented("Explicit padding is not supported ",
"when CUDNN is not enabled."));
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
if (is_int8x4) {
@ -1165,6 +1240,7 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
std::vector<int32> ksize_;
std::vector<int32> stride_;
Padding padding_;
std::vector<int64> explicit_paddings_;
TensorFormat data_format_;
bool propagate_nans_;
};
@ -1228,8 +1304,13 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
PoolParameters params{context, ksize, stride,
padding_, data_format_, tensor_in.shape()};
PoolParameters params{context,
ksize,
stride,
padding_,
/*explicit_paddings=*/{},
data_format_,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -1239,8 +1320,9 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
params.out_width, params.depth);
if (data_format_ == FORMAT_NCHW) {
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize,
stride, padding_, data_format_, tensor_in,
out_shape, propagate_nans_);
stride, padding_, explicit_paddings_,
data_format_, tensor_in, out_shape,
propagate_nans_);
} else {
CHECK(data_format_ == FORMAT_NHWC)
<< "MaxPool only supports NCHW or NHWC format";
@ -1255,6 +1337,7 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
std::vector<int32> ksize_;
std::vector<int32> stride_;
Padding padding_;
std::vector<int64> explicit_paddings_;
TensorFormat data_format_;
bool propagate_nans_;
};
@ -1267,7 +1350,7 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
params.tensor_in_cols, params.depth, params.out_height,
params.out_width, params.window_rows, params.window_cols,
params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
params.row_stride, params.col_stride, params.pad_top, params.pad_left,
output->flat<T>().data(), nullptr, context->eigen_gpu_device(),
propagate_nans, false);
if (!status) {
@ -1286,7 +1369,7 @@ struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
params.tensor_in_cols, params.depth, params.out_height,
params.out_width, params.window_rows, params.window_cols,
params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
params.row_stride, params.col_stride, params.pad_top, params.pad_left,
output->flat<T>().data(),
reinterpret_cast<int64*>(argmax->flat<int64>().data()),
context->eigen_gpu_device(), propagate_nans, include_batch_in_index);

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@ -49,12 +50,75 @@ struct RawType<qint8> {
using type = int8;
};
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
struct PadInputWithNegativeInf {
Status operator()(const GPUDevice& d,
typename TTypes<T, 4, int>::ConstTensor in,
int input_pad_top, int input_pad_bottom, int input_pad_left,
int input_pad_right, typename TTypes<T, 4, int>::Tensor out,
TensorFormat format) {
T padding_value = -std::numeric_limits<T>::infinity();
functor::PadInput<GPUDevice, T, int, 4>()(
d, in, {{input_pad_top, input_pad_left}},
{{input_pad_bottom, input_pad_right}}, out, format, padding_value);
return Status::OK();
}
};
template <>
struct PadInputWithNegativeInf<qint8> {
Status operator()(const GPUDevice& d,
typename TTypes<qint8, 4, int>::ConstTensor in,
int input_pad_top, int input_pad_bottom, int input_pad_left,
int input_pad_right,
typename TTypes<qint8, 4, int>::Tensor out,
TensorFormat format) {
return errors::InvalidArgument(
"Explicit padding not yet supported with qint8");
}
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace
Status CheckPaddingSize(int64 window_rows, int64 window_cols, int64 pad_top,
int64 pad_bottom, int64 pad_left, int64 pad_right) {
if (!FastBoundsCheck(pad_top, window_rows)) {
return errors::InvalidArgument("Top padding ", pad_top,
" needs to be smaller than the "
"window size ",
window_rows);
}
if (!FastBoundsCheck(pad_bottom, window_rows)) {
return errors::InvalidArgument("Bottom padding ", pad_bottom,
" needs to be smaller than the "
"window size ",
window_rows);
}
if (!FastBoundsCheck(pad_left, window_cols)) {
return errors::InvalidArgument("Left padding ", pad_left,
" needs to be smaller than the "
"window size ",
window_cols);
}
if (!FastBoundsCheck(pad_right, window_cols)) {
return errors::InvalidArgument("Right padding ", pad_right,
" needs to be smaller than the "
"window size ",
window_cols);
}
return Status::OK();
}
PoolParameters::PoolParameters(OpKernelContext* context,
const std::vector<int32>& ksize,
const std::vector<int32>& stride,
Padding padding, TensorFormat data_format,
Padding padding,
std::vector<int64> explicit_paddings,
TensorFormat data_format,
const TensorShape& tensor_in_shape) {
// For maxpooling, tensor_in should have 2 spatial dimensions.
// Note: the total number of dimensions could be 4 for NHWC, NCHW,
@ -85,14 +149,24 @@ PoolParameters::PoolParameters(OpKernelContext* context,
errors::Unimplemented(
"MaxPooling supports exactly one of pooling across depth "
"or pooling across width/height."));
if (padding == Padding::EXPLICIT) {
OP_REQUIRES_OK(context, CheckValidPadding(padding, explicit_paddings,
/*num_dims=*/4, data_format));
GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &pad_top,
&pad_bottom);
GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &pad_left,
&pad_right);
OP_REQUIRES_OK(context, CheckPaddingSize(window_rows, window_cols, pad_top,
pad_bottom, pad_left, pad_right));
}
if (depth_window == 1) {
OP_REQUIRES_OK(
context, GetWindowedOutputSize(tensor_in_rows, window_rows, row_stride,
padding, &out_height, &pad_rows));
OP_REQUIRES_OK(
context, GetWindowedOutputSize(tensor_in_cols, window_cols, col_stride,
padding, &out_width, &pad_cols));
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
tensor_in_rows, window_rows, row_stride,
padding, &out_height, &pad_top, &pad_bottom));
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
tensor_in_cols, window_cols, col_stride,
padding, &out_width, &pad_left, &pad_right));
pad_depth = 0;
out_depth = depth;
} else {
@ -140,6 +214,7 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
se::dnn::PoolingMode pooling_mode,
const std::vector<int32>& size,
const std::vector<int32>& stride, Padding padding,
std::vector<int64> explicit_paddings,
TensorFormat data_format, const Tensor& tensor_in,
const TensorShape& tensor_out_shape,
bool propagate_nans) {
@ -150,14 +225,18 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
return;
}
PoolParameters params{context, size, stride,
padding, data_format, tensor_in.shape()};
PoolParameters params{
context, size, stride, padding,
explicit_paddings, data_format, tensor_in.shape()};
if (!context->status().ok()) {
return;
}
int batch_size = params.tensor_in_batch;
int depth = params.depth;
int tensor_in_cols = params.tensor_in_cols;
int tensor_in_rows = params.tensor_in_rows;
#if CUDNN_VERSION < 7300
/// Earlier versions do not support NHWC format, so we need to convert it
/// to NCHW before calling cudnn. We need to get rid of this once it is done
@ -186,7 +265,7 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
}
se::dnn::DataLayout data_layout = se::dnn::DataLayout::kBatchDepthYX;
#else
auto& transformed_input = tensor_in;
Tensor transformed_input = tensor_in;
auto& transformed_output = *tensor_out;
se::dnn::DataLayout data_layout;
switch (data_format) {
@ -209,21 +288,81 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
ToString(data_format)));
}
#endif
/// Get ready to call cudnn
int64 vertical_padding = params.pad_top;
int64 horizontal_padding = params.pad_left;
if (padding == EXPLICIT && (params.pad_top != params.pad_bottom ||
params.pad_left != params.pad_right)) {
// cuDNN only supports padding the same amount on the left and right sides,
// and on the top and bottom sides. So we manually create a new padded
// input tensor such that we can pass it to cuDNN.
const int64 common_padding_rows =
std::min(params.pad_top, params.pad_bottom);
const int64 common_padding_cols =
std::min(params.pad_left, params.pad_right);
Tensor padded_input;
const int64 padding_rows_diff =
std::abs(params.pad_top - params.pad_bottom);
const int64 padding_cols_diff =
std::abs(params.pad_left - params.pad_right);
const int64 new_in_rows = tensor_in_rows + padding_rows_diff;
const int64 new_in_cols = tensor_in_cols + padding_cols_diff;
OP_REQUIRES_OK(
context,
context->allocate_temp(DataTypeToEnum<T>::value,
ShapeFromFormat(data_format, batch_size,
new_in_rows, new_in_cols, depth),
&padded_input));
const int64 input_pad_top = params.pad_top - common_padding_rows;
const int64 input_pad_bottom = params.pad_bottom - common_padding_rows;
const int64 input_pad_left = params.pad_left - common_padding_cols;
const int64 input_pad_right = params.pad_right - common_padding_cols;
bool in_bounds =
FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
if (!in_bounds) {
context->SetStatus(errors::InvalidArgument("Padding is too large."));
return;
}
// We need to call the const version of transformed_input.tensor()
const Tensor& const_transformed_input = transformed_input;
OP_REQUIRES_OK(
context,
PadInputWithNegativeInf<T>()(
context->eigen_device<GPUDevice>(),
To32Bit(const_transformed_input.tensor<T, 4>()),
static_cast<int>(input_pad_top), static_cast<int>(input_pad_bottom),
static_cast<int>(input_pad_left), static_cast<int>(input_pad_right),
To32Bit(padded_input.tensor<T, 4>()), data_format));
transformed_input = padded_input;
vertical_padding = common_padding_rows;
horizontal_padding = common_padding_cols;
tensor_in_rows = new_in_rows;
tensor_in_cols = new_in_cols;
}
se::dnn::PoolingDescriptor pooling_desc;
pooling_desc.set_pooling_mode(pooling_mode)
.set_window_height(params.window_rows)
.set_window_width(params.window_cols)
.set_vertical_stride(params.row_stride)
.set_horizontal_stride(params.col_stride)
.set_vertical_padding(params.pad_rows)
.set_horizontal_padding(params.pad_cols)
.set_vertical_padding(vertical_padding)
.set_horizontal_padding(horizontal_padding)
.set_propagate_nans(propagate_nans);
se::dnn::BatchDescriptor input_desc;
input_desc.set_count(batch_size)
.set_height(params.tensor_in_rows)
.set_width(params.tensor_in_cols)
.set_height(tensor_in_rows)
.set_width(tensor_in_cols)
.set_feature_map_count(depth)
.set_layout(data_layout);
@ -280,13 +419,32 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
#endif
}
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void PadInput<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const std::array<int, 2>& padding_left, \
const std::array<int, 2>& padding_right, \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
T padding_value); \
extern template struct PadInput<GPUDevice, T, int, 4>;
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(int32);
} // namespace functor
template <typename T>
void DnnPoolingGradOp<T>::Compute(
OpKernelContext* context, se::dnn::PoolingMode pooling_mode,
const std::vector<int32>& size, const std::vector<int32>& stride,
Padding padding, TensorFormat data_format, const Tensor* tensor_in,
const Tensor* tensor_out, const Tensor& out_backprop,
const TensorShape& tensor_in_shape, bool propagate_nans) {
Padding padding, std::vector<int64> explicit_paddings,
TensorFormat data_format, const Tensor* tensor_in, const Tensor* tensor_out,
const Tensor& out_backprop, const TensorShape& tensor_in_shape,
bool propagate_nans) {
CHECK((pooling_mode != se::dnn::PoolingMode::kMaximum) ||
(tensor_in && tensor_out))
<< "For MaxPoolGrad, both tensor_in and tensor_out needs to be "
@ -299,8 +457,8 @@ void DnnPoolingGradOp<T>::Compute(
return;
}
PoolParameters params{context, size, stride,
padding, data_format, tensor_in_shape};
PoolParameters params{context, size, stride, padding,
explicit_paddings, data_format, tensor_in_shape};
if (!context->status().ok()) {
return;
}
@ -406,6 +564,98 @@ void DnnPoolingGradOp<T>::Compute(
}
#endif // CUDNN_VERSION < 7300
int64 vertical_padding = params.pad_top;
int64 horizontal_padding = params.pad_left;
int batch_size = params.tensor_in_batch;
int depth = params.depth;
int tensor_in_cols = params.tensor_in_cols;
int tensor_in_rows = params.tensor_in_rows;
int64 input_pad_top = 0;
int64 input_pad_bottom = 0;
int64 input_pad_left = 0;
int64 input_pad_right = 0;
if (padding == EXPLICIT && (params.pad_top != params.pad_bottom ||
params.pad_left != params.pad_right)) {
// Pad the input in the same way we did during the forward pass, so that
// cuDNN or MIOpen receives the same input during the backward pass function
// as it did during the forward pass function.
const int64 common_padding_rows =
std::min(params.pad_top, params.pad_bottom);
const int64 common_padding_cols =
std::min(params.pad_left, params.pad_right);
Tensor padded_input;
Tensor padded_input_backprop;
const int64 padding_rows_diff =
std::abs(params.pad_top - params.pad_bottom);
const int64 padding_cols_diff =
std::abs(params.pad_left - params.pad_right);
const int64 new_in_rows = tensor_in_rows + padding_rows_diff;
const int64 new_in_cols = tensor_in_cols + padding_cols_diff;
VLOG(2) << "Create new tensor: "
<< " original rows=" << tensor_in_rows
<< " original cols=" << tensor_in_cols
<< " padding_rows=" << new_in_rows
<< " padding_cols=" << new_in_cols << " depth= " << depth
<< " batch_size=" << batch_size << " kernel_rows"
<< params.window_rows << " kernel_col" << params.window_cols
<< " stride_rows" << params.row_stride;
OP_REQUIRES_OK(
context,
context->allocate_temp(DataTypeToEnum<T>::value,
ShapeFromFormat(data_format, batch_size,
new_in_rows, new_in_cols, depth),
&padded_input));
OP_REQUIRES_OK(
context,
context->allocate_temp(DataTypeToEnum<T>::value,
ShapeFromFormat(data_format, batch_size,
new_in_rows, new_in_cols, depth),
&transformed_input_backprop));
input_pad_top = params.pad_top - common_padding_rows;
input_pad_bottom = params.pad_bottom - common_padding_rows;
input_pad_left = params.pad_left - common_padding_cols;
input_pad_right = params.pad_right - common_padding_cols;
bool in_bounds =
FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
if (!in_bounds) {
context->SetStatus(errors::InvalidArgument("Padding is too large."));
return;
}
// PadInputWithNegativeInf functor requires input to be a const.
const Tensor& const_transformed_input = transformed_input;
OP_REQUIRES_OK(
context,
PadInputWithNegativeInf<T>()(
context->eigen_device<GPUDevice>(),
To32Bit(const_transformed_input.tensor<T, 4>()),
static_cast<int>(input_pad_top), static_cast<int>(input_pad_bottom),
static_cast<int>(input_pad_left), static_cast<int>(input_pad_right),
To32Bit(padded_input.tensor<T, 4>()), data_format));
transformed_input = padded_input;
vertical_padding = common_padding_rows;
horizontal_padding = common_padding_cols;
VLOG(2) << "vertical padding set to: " << vertical_padding
<< " horizontal padding set to: " << horizontal_padding;
tensor_in_rows = new_in_rows;
tensor_in_cols = new_in_cols;
}
/// Get ready to call cudnn
se::dnn::PoolingDescriptor pooling_desc;
pooling_desc.set_pooling_mode(pooling_mode)
@ -413,8 +663,8 @@ void DnnPoolingGradOp<T>::Compute(
.set_window_width(params.window_cols)
.set_vertical_stride(params.row_stride)
.set_horizontal_stride(params.col_stride)
.set_vertical_padding(params.pad_rows)
.set_horizontal_padding(params.pad_cols)
.set_vertical_padding(vertical_padding)
.set_horizontal_padding(horizontal_padding)
.set_propagate_nans(propagate_nans);
se::dnn::BatchDescriptor orig_output_desc;
@ -426,8 +676,8 @@ void DnnPoolingGradOp<T>::Compute(
se::dnn::BatchDescriptor orig_input_desc;
orig_input_desc.set_count(params.tensor_in_batch)
.set_height(params.tensor_in_rows)
.set_width(params.tensor_in_cols)
.set_height(tensor_in_rows)
.set_width(tensor_in_cols)
.set_feature_map_count(params.depth)
.set_layout(data_layout);
@ -482,6 +732,18 @@ void DnnPoolingGradOp<T>::Compute(
input_backprop->tensor<T, 4>());
}
#endif // CUDNN_VERSION < 7300
if (padding == EXPLICIT && (params.pad_top != params.pad_bottom ||
params.pad_left != params.pad_right)) {
// Remove the padding that was added to the input shape above.
functor::PadInput<GPUDevice, T, int, 4>()(
context->eigen_device<GPUDevice>(),
To32Bit(const_cast<const Tensor&>(transformed_input_backprop)
.tensor<T, 4>()),
{{static_cast<int>(-input_pad_top), static_cast<int>(-input_pad_left)}},
{{static_cast<int>(-input_pad_bottom),
static_cast<int>(-input_pad_right)}},
To32Bit(input_backprop->tensor<T, 4>()), data_format);
}
}
#define DEFINE_DNN_OPS(T) \

View File

@ -18,7 +18,12 @@ limitations under the License.
#include <vector>
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
@ -40,9 +45,12 @@ typedef Eigen::GpuDevice GPUDevice;
// A helper class to manage sizes and shapes for pooling operations.
struct PoolParameters {
// Updates context->status if there is an invalid input.
// explicit_paddings has eight elements if padding==EXPLIICT, and zero
// elements otherwise.
PoolParameters(OpKernelContext* context, const std::vector<int32>& ksize,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format, const TensorShape& tensor_in_shape);
std::vector<int64> explicit_paddings, TensorFormat data_format,
const TensorShape& tensor_in_shape);
// Returns the shape of the output for "forward" pooling operations.
TensorShape forward_output_shape();
@ -65,13 +73,21 @@ struct PoolParameters {
int64 out_width;
int out_depth;
int64 pad_rows;
int64 pad_cols;
int64 pad_top;
int64 pad_bottom;
int64 pad_left;
int64 pad_right;
int pad_depth;
TensorFormat data_format;
};
// Checks if the sizes of the paddings are less than the size of window.
// This is required for MaxPool because it pads with -inf, so the pooling
// window cannot fully cover the padded area.
Status CheckPaddingSize(PoolParameters& params);
// An implementation of MaxPooling (forward).
// TODO (yongtang): Remove MaxPoolingOp and use MaxPoolingV2Op,
// QuantizedMaxPoolingOp depends on MaxPoolingOp so keep intact for now
@ -106,6 +122,10 @@ class MaxPoolingOp : public OpKernel {
errors::InvalidArgument("Sliding window stride field must "
"specify 4 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
if (padding_ == Padding::EXPLICIT) {
OP_REQUIRES_OK(
context, context->GetAttr("explicit_paddings", &explicit_paddings_));
}
OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
@ -113,8 +133,9 @@ class MaxPoolingOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& tensor_in = context->input(0);
PoolParameters params{context, ksize_, stride_,
padding_, FORMAT_NHWC, tensor_in.shape()};
PoolParameters params{
context, ksize_, stride_, padding_, explicit_paddings_,
FORMAT_NHWC, tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -134,9 +155,21 @@ class MaxPoolingOp : public OpKernel {
context, params.depth_window == params.depth_stride,
errors::Unimplemented("Depthwise max pooling requires "
"the depth window to equal the depth stride."));
OP_REQUIRES(
context, padding_ != EXPLICIT,
errors::Unimplemented("Depthwise max pooling does not support "
"explicit padding."));
DepthwiseMaxPool(context, output, tensor_in, params);
} else {
// MaxPoolingOp is only called on the GPU when the eigen_tensor label
// is used. In this case, explicit padding is not supported
if (std::is_same<Device, GPUDevice>::value &&
padding_ == Padding::EXPLICIT) {
context->SetStatus(errors::Unimplemented(
"MaxPoolingOp does not support explicit padding."));
return;
}
SpatialMaxPool(context, output, tensor_in, params, padding_);
}
}
@ -202,8 +235,8 @@ class MaxPoolingOp : public OpKernel {
auto shard = [&params, &in_mat, &out_mat](int64 start, int64 limit) {
const int32 in_rows = params.tensor_in_rows;
const int32 in_cols = params.tensor_in_cols;
const int32 pad_rows = params.pad_rows;
const int32 pad_cols = params.pad_cols;
const int32 pad_top = params.pad_top;
const int32 pad_left = params.pad_left;
const int32 window_rows = params.window_rows;
const int32 window_cols = params.window_cols;
const int32 row_stride = params.row_stride;
@ -225,8 +258,8 @@ class MaxPoolingOp : public OpKernel {
for (int32 w = 0; w < in_cols; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
const int32 hpad = h + pad_rows;
const int32 wpad = w + pad_cols;
const int32 hpad = h + pad_top;
const int32 wpad = w + pad_left;
const int32 h_start = (hpad < window_rows)
? 0
: (hpad - window_rows) / row_stride + 1;
@ -263,6 +296,7 @@ class MaxPoolingOp : public OpKernel {
std::vector<int32> ksize_;
std::vector<int32> stride_;
Padding padding_;
std::vector<int64> explicit_paddings_;
TensorFormat data_format_;
};
@ -280,7 +314,7 @@ struct LaunchMaxPoolingNoMask_NCHW_VECT_C<Eigen::GpuDevice> {
params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols,
params.depth, params.out_height, params.out_width, params.window_rows,
params.window_cols, params.row_stride, params.col_stride,
params.pad_rows, params.pad_cols,
params.pad_top, params.pad_left,
reinterpret_cast<int32*>(output->flat<qint8>().data()),
context->eigen_gpu_device());
if (!status) {
@ -358,8 +392,15 @@ class MaxPoolingV2Op : public OpKernel {
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
PoolParameters params{context, ksize, stride,
padding_, data_format_, tensor_in.shape()};
PoolParameters params{
context,
ksize,
stride,
padding_,
/*explicit_paddings=*/{},
data_format_,
tensor_in.shape(),
};
if (!context->status().ok()) {
return;
}
@ -455,8 +496,8 @@ class MaxPoolingV2Op : public OpKernel {
auto shard = [&params, &in_mat, &out_mat](int64 start, int64 limit) {
const int32 in_rows = params.tensor_in_rows;
const int32 in_cols = params.tensor_in_cols;
const int32 pad_rows = params.pad_rows;
const int32 pad_cols = params.pad_cols;
const int32 pad_top = params.pad_top;
const int32 pad_left = params.pad_left;
const int32 window_rows = params.window_rows;
const int32 window_cols = params.window_cols;
const int32 row_stride = params.row_stride;
@ -478,8 +519,8 @@ class MaxPoolingV2Op : public OpKernel {
for (int32 w = 0; w < in_cols; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
const int32 hpad = h + pad_rows;
const int32 wpad = w + pad_cols;
const int32 hpad = h + pad_top;
const int32 wpad = w + pad_left;
const int32 h_start = (hpad < window_rows)
? 0
: (hpad - window_rows) / row_stride + 1;
@ -567,8 +608,8 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
for (int w = 0; w < params.tensor_in_cols; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
const int hpad = h + params.pad_rows;
const int wpad = w + params.pad_cols;
const int hpad = h + params.pad_top;
const int wpad = w + params.pad_left;
const int h_start =
(hpad < params.window_rows)
? 0

View File

@ -43,6 +43,7 @@ class DnnPoolingOp {
se::dnn::PoolingMode pooling_mode,
const std::vector<int32>& size,
const std::vector<int32>& stride, Padding padding,
std::vector<int64> explicit_paddings,
TensorFormat data_format, const Tensor& tensor_in,
const TensorShape& tensor_out_shape, bool propagate_nans);
};
@ -58,6 +59,7 @@ class DnnPoolingGradOp {
se::dnn::PoolingMode pooling_mode,
const std::vector<int32>& size,
const std::vector<int32>& stride, Padding padding,
std::vector<int64> explicit_paddings,
TensorFormat data_format, const Tensor* tensor_in,
const Tensor* tensor_out, const Tensor& out_backprop,
const TensorShape& tensor_in_shape, bool propagate_nans);

View File

@ -54,8 +54,13 @@ class QuantizedAvgPoolingOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& tensor_in = context->input(0);
PoolParameters params{context, ksize_, stride_,
padding_, FORMAT_NHWC, tensor_in.shape()};
PoolParameters params{context,
ksize_,
stride_,
padding_,
/*explicit_paddings=*/{},
FORMAT_NHWC,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}

View File

@ -846,11 +846,12 @@ REGISTER_OP("MaxPool")
"uint16, qint8} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetPaddingAttrStringWithExplicit())
.Attr(GetExplicitPaddingsAttrString())
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
.Input("input: T")
.Output("output: T")
.SetShapeFn(shape_inference::MaxPoolShape);
.SetShapeFn(shape_inference::MaxPoolShapeWithExplicitPadding);
REGISTER_OP("MaxPoolV2")
.Attr(
@ -870,7 +871,8 @@ REGISTER_OP("MaxPoolV2")
REGISTER_OP("MaxPoolGrad")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetPaddingAttrStringWithExplicit())
.Attr(GetExplicitPaddingsAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Input("orig_input: T")
.Input("orig_output: T")
@ -895,6 +897,7 @@ REGISTER_OP("MaxPoolGradV2")
return UnchangedShapeWithRank(c, 4);
});
// TODO(b/150813181): Implement explicit padding.
REGISTER_OP("MaxPoolGradGrad")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")

View File

@ -337,13 +337,13 @@ def NHWCToNCHW(input_tensor):
"""Converts the input from the NHWC format to NCHW.
Args:
input_tensor: a 4- or 5-D tensor, or an array representing shape
input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape
Returns:
converted tensor or shape array
"""
# tensor dim -> new axis order
new_axes = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
new_axes = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
if isinstance(input_tensor, ops.Tensor):
ndims = input_tensor.shape.ndims
return array_ops.transpose(input_tensor, new_axes[ndims])

View File

@ -50,15 +50,21 @@ def GetDeviceScope(self, use_gpu=False):
return self.session(use_gpu=use_gpu)
def GetTestConfigs(include_nchw_vect_c=False):
def GetTestConfigs(include_nchw_vect_c=False, one_dimensional=False):
"""Get all the valid tests configs to run.
Args:
include_nchw_vect_c: Whether to include NCHW_VECT_C in the test configs.
one_dimensional: If it's a 1D test
Returns:
all the valid test configs as tuples of data_format and use_gpu.
"""
if one_dimensional:
test_configs = [("NWC", False), ("NWC", True)]
if test.is_gpu_available(cuda_only=True):
test_configs += [("NCW", True)]
return test_configs
test_configs = [("NHWC", False), ("NHWC", True)]
if not test.is_gpu_available(cuda_only=True):
tf_logging.info("NCHW and NCHW_VECT_C tests skipped because not run with "
@ -106,8 +112,12 @@ def GetShrunkInceptionMaxPoolShapes(shrink=30):
class PoolingTest(test.TestCase):
def _isMaxPool(self, func):
return func in (nn_ops.max_pool, nn_ops.max_pool_v2)
def _VerifyOneType(self, pool_func, input_sizes, ksize, strides, padding,
data_format, data_type, expected, use_gpu, v2):
data_format, data_type, expected, use_gpu, v2,
use_negative_input=False):
"""Verifies the output values of the pooling function.
Args:
@ -121,6 +131,8 @@ class PoolingTest(test.TestCase):
data_type: The data type to use to run the pooling operation.
expected: An array containing the expected operation outputs.
use_gpu: Whether we are running on GPU.
v2: Whether to use v2 version.
use_negative_input: If the input values should be negative.
"""
total_size = 1
for s in input_sizes:
@ -141,10 +153,11 @@ class PoolingTest(test.TestCase):
data_type)
# Initializes the input tensor with array containing incrementing
# numbers from 1, wrapping round to -127 after 127 to support int8.
x = [((f + 128) % 255) - 127 for f in range(total_size)]
y = -1 if use_negative_input else 1
x = [(((f + 128) % 255) - 127)*y for f in range(total_size)]
with self.cached_session(use_gpu=use_gpu):
t = constant_op.constant(x, shape=input_sizes, dtype=data_type)
if data_format in ("NCHW", "NCHW_VECT_C"):
if data_format in ("NCHW", "NCHW_VECT_C", "NCW"):
if data_format == "NCHW_VECT_C":
t = test_util.NHWCToNCHW_VECT_C(t)
t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
@ -152,6 +165,8 @@ class PoolingTest(test.TestCase):
t = test_util.NHWCToNCHW(t)
ksize = test_util.NHWCToNCHW(ksize)
strides = test_util.NHWCToNCHW(strides)
if isinstance(padding, list):
padding = test_util.NHWCToNCHW(padding)
ksize_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
strides_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
if v2:
@ -184,7 +199,8 @@ class PoolingTest(test.TestCase):
self.assertAllCloseAccordingToType(expected, actual.flatten())
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
data_format, expected, use_gpu, v2):
data_format, expected, use_gpu, v2,
use_negative_input=False):
"""Verifies the output values of the pooling function.
Args:
@ -197,6 +213,8 @@ class PoolingTest(test.TestCase):
data_format: The data format we use to run the pooling operation.
expected: An array containing the expected operation outputs.
use_gpu: Whether we are running on GPU.
v2: Whether to use v2 version.
use_negative_input: If the input values should be negative."
"""
if data_format == "NCHW_VECT_C":
avg_pool_func = nn_ops.avg_pool
@ -204,17 +222,24 @@ class PoolingTest(test.TestCase):
if pool_func == avg_pool_func:
tf_logging.info("NCHW_VECT_C not yet implemented for avg_pool")
return
if (self._isMaxPool(pool_func) and isinstance(padding, list)):
tf_logging.info("NCHW_VECT_C not yet implemented for max pool" +
" with explicit padding")
return
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
data_format, dtypes.float32, expected, use_gpu, v2)
data_format, dtypes.float32, expected, use_gpu, v2,
use_negative_input)
if not test.is_built_with_rocm():
# double datatype is not supported for pooling ops on the ROCm platform
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
data_format, dtypes.float64, expected, use_gpu, v2)
data_format, dtypes.float64, expected, use_gpu, v2,
use_negative_input)
if not use_gpu or test_util.GpuSupportsHalfMatMulAndConv():
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
data_format, dtypes.float16, expected, use_gpu, v2)
data_format, dtypes.float16, expected, use_gpu, v2,
use_negative_input)
def _VerifyValues(self,
pool_func,
@ -224,7 +249,9 @@ class PoolingTest(test.TestCase):
padding,
expected,
use_gpu,
v2=False):
v2=False,
one_dim=False,
use_negative_input=False):
"""Verifies the output values of the pooling function.
Args:
@ -236,11 +263,16 @@ class PoolingTest(test.TestCase):
padding: Padding type.
expected: An array containing the expected operation outputs.
use_gpu: Whether we are running on GPU.
v2: Whether to use v2 version.
one_dim: If one dimensional pools should be done instead of two
dimensional pools.
use_negative_input: If the input values should be negative.
"""
for (data_format, use_gpu_2) in GetTestConfigs(True):
for (data_format, use_gpu_2) in GetTestConfigs(True, one_dim):
if use_gpu_2 == use_gpu:
self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
data_format, expected, use_gpu, v2)
data_format, expected, use_gpu, v2,
use_negative_input)
def _testAvgPoolValidPadding(self, use_gpu):
expected_output = [7.0, 8.0, 9.0]
@ -467,6 +499,101 @@ class PoolingTest(test.TestCase):
use_gpu=use_gpu,
v2=v2)
def _testMaxPoolZeroExplicitPadding(self, use_gpu):
expected_output = [9.0]
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 3, 3, 1],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding=[[0, 0], [0, 0], [0, 0], [0, 0]],
expected=expected_output,
use_gpu=use_gpu)
def _testMaxPoolNegativeInputExpPadding(self, use_gpu):
expected_output = [-1, -1, -1, -1]
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 3, 3, 1],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding=[[0, 0], [2, 1], [2, 1], [0, 0]],
expected=expected_output,
use_gpu=use_gpu,
use_negative_input=True)
def _testMaxPoolExplicitPadding(self, use_gpu):
expected_output = [9.0, 9.0]
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 3, 3, 1],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding=[[0, 0], [0, 2], [0, 1], [0, 0]],
expected=expected_output,
use_gpu=use_gpu)
def _testMaxPoolExplicitPaddingAdvanced(self, use_gpu):
expected_output = [7, 9, 11, 12, 19, 21, 23, 24, 31, 33, 35, 36, 31, 33,
35, 36]
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 6, 6, 1],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding=[[0, 0], [1, 2], [2, 1], [0, 0]],
expected=expected_output,
use_gpu=use_gpu)
def _testMaxPoolNegativeInputExpPaddingAdv(self, use_gpu):
expected_output = [-1, -1, -3, -5, -7, -7, -9, -11, -19, -19, -21, -23, -31,
-31, -33, -35]
self._VerifyValues(
nn_ops.max_pool,
input_sizes=[1, 6, 6, 1],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding=[[0, 0], [1, 2], [2, 1], [0, 0]],
expected=expected_output,
use_gpu=use_gpu,
use_negative_input=True)
def _testMaxPoolExplicitPaddingV2(self, use_gpu):
expected_output = [9.0, 9.0]
self._VerifyValues(
nn_ops.max_pool_v2,
input_sizes=[1, 3, 3, 1],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding=[[0, 0], [0, 2], [0, 1], [0, 0]],
expected=expected_output,
use_gpu=use_gpu)
def _testMaxPoolExplicitPadding1D(self, use_gpu):
expected_output = [2.0, 3.0]
self._VerifyValues(
nn_ops.max_pool1d,
input_sizes=[1, 3, 1],
ksize=[1, 2, 1],
strides=[1, 2, 1],
padding=[[0, 0], [0, 1], [0, 0]],
expected=expected_output,
use_gpu=use_gpu,
one_dim=True)
def _testMaxPoolExplicitPadding1dV2(self, use_gpu):
expected_output = [2.0, 3.0]
self._VerifyValues(
nn_ops.max_pool_v2,
input_sizes=[1, 3, 1],
ksize=[1, 2, 1],
strides=[1, 2, 1],
padding=[[0, 0], [0, 1], [0, 0]],
expected=expected_output,
use_gpu=use_gpu,
one_dim=True)
def _testMaxPoolSamePaddingNonSquareWindow(self, use_gpu):
# input is:
# [1.0, 2.0
@ -618,6 +745,14 @@ class PoolingTest(test.TestCase):
self._testMaxPoolSamePaddingPacket4(use_gpu)
self._testMaxPoolSamePaddingPacket8(use_gpu)
self._testMaxPoolEmptyInput(use_gpu)
self._testMaxPoolZeroExplicitPadding(use_gpu)
self._testMaxPoolExplicitPadding(use_gpu)
self._testMaxPoolExplicitPaddingV2(use_gpu)
self._testMaxPoolExplicitPadding1D(use_gpu)
self._testMaxPoolExplicitPadding1dV2(use_gpu)
self._testMaxPoolExplicitPaddingAdvanced(use_gpu)
self._testMaxPoolNegativeInputExpPadding(use_gpu)
self._testMaxPoolNegativeInputExpPaddingAdv(use_gpu)
# Tests for DepthwiseMaxPooling on CPU only.
@test_util.run_deprecated_v1
@ -980,7 +1115,7 @@ class PoolingTest(test.TestCase):
data_format,
use_gpu,
x_init_value=None):
"""Verifies the gradients of the avg pooling function.
"""Verifies the gradients of the max or avg pooling function.
Args:
pool_func: Function to be called, co.MaxPool, co.AvgPool,
@ -1017,11 +1152,13 @@ class PoolingTest(test.TestCase):
func_name = "max_pool"
err_tolerance = 1e-3
if data_format == "NCHW":
ksize = [1, 1, window_rows, window_rows]
ksize = [1, 1, window_rows, window_cols]
strides = [1, 1, row_stride, col_stride]
if isinstance(padding, list):
padding = test_util.NHWCToNCHW(padding)
t = test_util.NHWCToNCHW(input_tensor)
else:
ksize = [1, window_rows, window_rows, 1]
ksize = [1, window_rows, window_cols, 1]
strides = [1, row_stride, col_stride, 1]
t = input_tensor
t = pool_func(
@ -1261,6 +1398,76 @@ class PoolingTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
def _testMaxPoolExplicitPadding1(self, data_format, use_gpu):
for pool_func in [nn_ops.max_pool]:
self._ConstructAndTestGradient(
pool_func,
input_sizes=[1, 7, 7, 1],
output_sizes=[1, 7, 7, 1],
window_rows=3,
window_cols=3,
row_stride=1,
col_stride=1,
padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
data_format=data_format,
use_gpu=use_gpu)
def _testMaxPoolExplicitPadding2(self, data_format, use_gpu):
for pool_func in [nn_ops.max_pool]:
self._ConstructAndTestGradient(
pool_func,
input_sizes=[1, 7, 7, 1],
output_sizes=[1, 6, 8, 1],
window_rows=3,
window_cols=5,
row_stride=1,
col_stride=1,
padding=[[0, 0], [0, 1], [2, 3], [0, 0]],
data_format=data_format,
use_gpu=use_gpu)
def _testMaxPoolExplicitPaddingLeftGreater(self, data_format, use_gpu):
for pool_func in [nn_ops.max_pool]:
self._ConstructAndTestGradient(
pool_func,
input_sizes=[1, 7, 7, 1],
output_sizes=[1, 6, 8, 1],
window_rows=3,
window_cols=5,
row_stride=1,
col_stride=1,
padding=[[0, 0], [0, 1], [3, 2], [0, 0]],
data_format=data_format,
use_gpu=use_gpu)
def _testMaxPoolExplicitPaddingBatchChannel(self, data_format, use_gpu):
for pool_func in [nn_ops.max_pool]:
self._ConstructAndTestGradient(
pool_func,
input_sizes=[4, 7, 7, 3],
output_sizes=[4, 6, 8, 3],
window_rows=3,
window_cols=5,
row_stride=1,
col_stride=1,
padding=[[0, 0], [0, 1], [3, 2], [0, 0]],
data_format=data_format,
use_gpu=use_gpu)
def _testMaxPoolExplicitPaddingStrides(self, data_format, use_gpu):
for pool_func in [nn_ops.max_pool]:
self._ConstructAndTestGradient(
pool_func,
input_sizes=[1, 7, 7, 1],
output_sizes=[1, 4, 3, 1],
window_rows=3,
window_cols=3,
row_stride=2,
col_stride=3,
padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
data_format=data_format,
use_gpu=use_gpu)
@test_util.run_deprecated_v1
def testMaxPoolGrad(self):
for (data_format, use_gpu) in GetTestConfigs():
@ -1274,6 +1481,11 @@ class PoolingTest(test.TestCase):
self._testMaxPoolGradSamePadding2_1(data_format, use_gpu)
self._testMaxPoolGradSamePadding2_2(data_format, use_gpu)
self._testMaxPoolGradSamePadding3_1(data_format, use_gpu)
self._testMaxPoolExplicitPadding1(data_format, use_gpu)
self._testMaxPoolExplicitPadding2(data_format, use_gpu)
self._testMaxPoolExplicitPaddingStrides(data_format, use_gpu)
self._testMaxPoolExplicitPaddingLeftGreater(data_format, use_gpu)
self._testMaxPoolExplicitPaddingBatchChannel(data_format, use_gpu)
def _MaxPoolGrad(self, orig_input, orig_output, grad, window_rows,
window_cols, row_stride, col_stride, padding, v2):
@ -1294,9 +1506,16 @@ class PoolingTest(test.TestCase):
A Tensor.
"""
pool_func = gen_nn_ops.max_pool_grad_v2 if v2 else gen_nn_ops.max_pool_grad
return pool_func(orig_input, orig_output, grad,
[1, window_rows, window_cols, 1],
[1, row_stride, col_stride, 1], padding)
if v2:
return pool_func(orig_input, orig_output, grad,
[1, window_rows, window_cols, 1],
[1, row_stride, col_stride, 1], padding)
else:
padding, explicit_paddings = nn_ops.convert_padding(padding)
return pool_func(orig_input, orig_output, grad,
[1, window_rows, window_cols, 1],
[1, row_stride, col_stride, 1], padding,
explicit_paddings)
def _testMaxPoolGradDirect(self, input_data, output_backprop,
expected_input_backprop, input_sizes, output_sizes,
@ -1439,6 +1658,116 @@ class PoolingTest(test.TestCase):
use_gpu=use_gpu,
v2=v2)
def _testMaxPoolGradZeroExplicitPadding(self):
input_data = [
1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
0.0, 1.0
]
output_backprop = [11.0, 12.0, 13.0, 15.0, 16.0, 17.0, 19.0, 20.0, 21.0]
expected_input_backprop = [
11.0, 0.0, 25.0, 0.0, 0.0, 31.0, 0.0, 17.0, 19.0, 0.0, 41.0, 0.0, 0.0,
0.0, 0.0, 0.0
]
for use_gpu in True, False:
for v2 in [False]:
self._testMaxPoolGradDirect(
input_data,
output_backprop,
expected_input_backprop,
input_sizes=[1, 4, 4, 1],
output_sizes=[1, 3, 3, 1],
window_rows=2,
window_cols=2,
row_stride=1,
col_stride=1,
padding=[[0, 0], [0, 0], [0, 0], [0, 0]],
use_gpu=use_gpu,
v2=v2)
def _testMaxPoolGradExplicitPadding_1(self):
input_data = [
1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
0.0, 1.0
]
output_backprop = [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
20.0, 21.0, 22.0]
expected_input_backprop = [
11.0, 0.0, 25.0, 0.0, 0.0, 31.0, 0.0, 49.0, 19.0, 0.0, 41.0, 0.0, 0.0,
0.0, 0.0, 22.0
]
for use_gpu in True, False:
for v2 in [False]:
self._testMaxPoolGradDirect(
input_data,
output_backprop,
expected_input_backprop,
input_sizes=[1, 4, 4, 1],
output_sizes=[1, 3, 4, 1],
window_rows=2,
window_cols=2,
row_stride=1,
col_stride=1,
padding=[[0, 0], [0, 0], [0, 1], [0, 0]],
use_gpu=use_gpu,
v2=v2)
def _testMaxPoolGradExplicitPadding_2(self):
input_data = [
1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
0.0, 1.0
]
output_backprop = [11.0, 12.0, 13.0, 15.0, 16.0, 17.0, 19.0, 20.0, 21.0]
expected_input_backprop = [
54.0, 0.0, 30.0, 0.0, 0.0, 0.0, 0.0, 0.0, 39.0, 0.0, 21.0, 0.0, 0.0,
0.0, 0.0, 0.0
]
for use_gpu in True, False:
for v2 in [False]:
self._testMaxPoolGradDirect(
input_data,
output_backprop,
expected_input_backprop,
input_sizes=[1, 4, 4, 1],
output_sizes=[1, 3, 3, 1],
window_rows=3,
window_cols=3,
row_stride=2,
col_stride=2,
padding=[[0, 0], [2, 1], [2, 1], [0, 0]],
use_gpu=use_gpu,
v2=v2)
def _testMaxPoolGradExplicitPadding_3(self):
input_data = [
-1.0, -5.0, -1.0, -5.0, -5.0, -1.0, -5.0, -1.0, -1.0, -5.0, -1.0, -5.0,
-5.0, -1.0, -5.0, -1.0
]
output_backprop = [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
20.0, 21.0, 22.0]
expected_input_backprop = [
11.0, 0.0, 25.0, 0.0, 0.0, 31.0, 0.0, 49.0, 19.0, 0.0, 41.0, 0.0, 0.0,
0.0, 0.0, 22.0
]
for use_gpu in True, False:
for v2 in [False]:
self._testMaxPoolGradDirect(
input_data,
output_backprop,
expected_input_backprop,
input_sizes=[1, 4, 4, 1],
output_sizes=[1, 3, 4, 1],
window_rows=2,
window_cols=2,
row_stride=1,
col_stride=1,
padding=[[0, 0], [0, 0], [0, 1], [0, 0]],
use_gpu=use_gpu,
v2=v2)
@test_util.no_xla_auto_jit("b/123923733") # NaNs handled differently
def _testMaxPoolGradDirectWithNans2_1(self):
input_data = [float("nan")] * 16
@ -1615,6 +1944,10 @@ class PoolingTest(test.TestCase):
self._testMaxPoolGradDirect1_3()
self._testMaxPoolGradDirectWithNans2_1()
self._testMaxPoolGradDirectWithNans2_2()
self._testMaxPoolGradZeroExplicitPadding()
self._testMaxPoolGradExplicitPadding_1()
self._testMaxPoolGradExplicitPadding_2()
self._testMaxPoolGradExplicitPadding_3()
def _testMaxPoolGradGradValidPadding1_1(self, data_format, use_gpu):
for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]:
@ -1956,6 +2289,94 @@ class PoolingTest(test.TestCase):
strides=[1, 1, 1, 1],
padding="VALID")
@test_util.run_deprecated_v1
def _testEdgeCasesRaiseErrors(self):
with self.assertRaisesRegexp(
ValueError, "Data formats NCHW_VECT_C is not yet supported with "
"explicit padding"):
nn_ops.max_pool(
array_ops.placeholder(dtypes.float32, shape=[1, 3, 3, 1]),
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding=[[0, 0], [0, 1], [0, 1], [0, 0]],
data_format="NCHW_VECT_C")
with self.assertRaisesRegexp(
ValueError, "Explicit padding is not yet supported with an input "
"tensor of rank 5"):
nn_ops.max_pool_v2(
array_ops.placeholder(dtypes.float32, shape=[1, 3, 3, 1, 1]),
ksize=[1, 2, 2, 1, 1],
strides=[1, 2, 2, 1, 1],
padding=[[0, 0], [0, 1], [0, 1], [0, 0]],
data_format="NCHW")
with self.assertRaisesRegexp(
ValueError, "Attr 'padding' of 'MaxPoolV2' Op passed "
"string 'EXPLICIT'"):
gen_nn_ops.max_pool_v2(
array_ops.placeholder(dtypes.float32, shape=[1, 3, 3, 1, 1]),
ksize=[1, 2, 2, 1, 1],
strides=[1, 2, 2, 1, 1],
padding="EXPLICIT",
data_format="NHWC")
@test_util.run_deprecated_v1
def _testEdgeCasesExcessPadding(self):
with self.session(use_gpu=test.is_gpu_available()) as sess:
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
"Right padding 2 needs to be smaller than the window size 2"):
input_sizes = [1, 3, 3, 1]
x = [(((f + 128) % 255) - 127) for f in range(9)]
t = constant_op.constant(x, shape=input_sizes, dtype=dtypes.float32)
sess.run(gen_nn_ops.max_pool(
t,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="EXPLICIT",
explicit_paddings=[0, 0, 0, 1, 0, 2, 0, 0],
data_format="NHWC"))
@test_util.run_deprecated_v1
def _testNegativePadding(self):
with self.session(use_gpu=test.is_gpu_available()) as sess:
with self.assertRaisesRegexp(
ValueError, "All elements of explicit_paddings must be "
"nonnegative for"):
input_sizes = [1, 3, 3, 1]
x = [(((f + 128) % 255) - 127) for f in range(9)]
t = constant_op.constant(x, shape=input_sizes, dtype=dtypes.float32)
sess.run(gen_nn_ops.max_pool(
t,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="EXPLICIT",
explicit_paddings=[0, 0, -1, -1, -1, -1, 0, 0],
data_format="NHWC"))
@test_util.run_deprecated_v1
def _testExplicitPaddingBatch(self):
with self.session(use_gpu=test.is_gpu_available()) as sess:
with self.assertRaisesRegexp(
ValueError, "Nonzero explicit padding in the batch or depth "
"dimensions is not supported"):
input_sizes = [1, 3, 3, 1]
x = [(((f + 128) % 255) - 127) for f in range(9)]
t = constant_op.constant(x, shape=input_sizes, dtype=dtypes.float32)
sess.run(gen_nn_ops.max_pool(
t,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding="EXPLICIT",
explicit_paddings=[1, 1, 1, 1, 1, 1, 0, 0],
data_format="NHWC"))
def testExplicitPaddingEdgeCases(self):
# Tests for Explicit padding.
self._testEdgeCasesRaiseErrors()
self._testEdgeCasesExcessPadding()
self._testExplicitPaddingBatch()
self._testNegativePadding()
def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):

View File

@ -688,6 +688,7 @@ def _MaxPoolGrad(op, grad):
op.get_attr("ksize"),
op.get_attr("strides"),
padding=op.get_attr("padding"),
explicit_paddings=op.get_attr("explicit_paddings"),
data_format=op.get_attr("data_format"))

View File

@ -1752,12 +1752,14 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
name=name)
def convert_padding(padding):
def convert_padding(padding, expected_length=4):
"""Converts Python padding to C++ padding for ops which take EXPLICIT padding.
Args:
padding: the `padding` argument for a Python op which supports EXPLICIT
padding.
expected_length: Expected number of entries in the padding list when
explicit padding is used.
Returns:
(padding, explicit_paddings) pair, which should be passed as attributes to a
@ -1783,9 +1785,9 @@ def convert_padding(padding):
"be a list/tuple of size 2. Element with index %d of "
"padding has size %d" % (i, len(dim_paddings)))
explicit_paddings.extend(dim_paddings)
if len(padding) != 4:
raise ValueError("When padding is a list, it must be of size 4. Got "
"padding of size: %d" % len(padding))
if len(padding) != expected_length:
raise ValueError("When padding is a list, it must be of size %d. Got "
"padding of size: %d" % (expected_length, len(padding)))
padding = "EXPLICIT"
return padding, explicit_paddings
@ -4481,8 +4483,15 @@ def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
of the window for each dimension of the input tensor.
strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
stride of the sliding window for each dimension of the input tensor.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
the "returns" section of `tf.nn.convolution` for details.
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
padding algorithm to use, or a list indicating the explicit paddings at
the start and end of each dimension. When explicit padding is used and
data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
[pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
padding, the size of the paddings cannot be greater than the sliding
window size.
data_format: A string. Specifies the channel dimension. For N=1 it can be
either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default)
or "NCHW" and for N=3 either "NDHWC" (default) or "NCDHW".
@ -4508,12 +4517,20 @@ def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
else:
channel_index = 1 if data_format.startswith("NC") else n + 1
if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
"explicit padding")
ksize = _get_sequence(ksize, n, channel_index, "ksize")
strides = _get_sequence(strides, n, channel_index, "strides")
if (isinstance(padding, (list, tuple)) and n == 3):
raise ValueError("Explicit padding is not yet supported with an input "
"tensor of rank 5")
max_pooling_ops = {
1: max_pool1d,
2: gen_nn_ops.max_pool,
2: max_pool2d,
3: gen_nn_ops.max_pool3d
}
@ -4545,8 +4562,15 @@ def max_pool(value,
The size of the window for each dimension of the input tensor.
strides: An int or list of `ints` that has length `1`, `2` or `4`.
The stride of the sliding window for each dimension of the input tensor.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
See the "returns" section of `tf.nn.convolution` for details.
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
padding algorithm to use, or a list indicating the explicit paddings at
the start and end of each dimension. When explicit padding is used and
data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
[pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
padding, the size of the paddings cannot be greater than the sliding
window size.
data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
name: Optional name for the operation.
input: Alias for value.
@ -4563,6 +4587,10 @@ def max_pool(value,
ksize = _get_sequence(ksize, 2, channel_index, "ksize")
strides = _get_sequence(strides, 2, channel_index, "strides")
if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
"explicit padding")
padding, explicit_paddings = convert_padding(padding)
if ((np.isscalar(ksize) and ksize == 0) or
(isinstance(ksize,
(list, tuple, np.ndarray)) and any(v == 0 for v in ksize))):
@ -4573,6 +4601,7 @@ def max_pool(value,
ksize=ksize,
strides=strides,
padding=padding,
explicit_paddings=explicit_paddings,
data_format=data_format,
name=name)
@ -4591,8 +4620,14 @@ def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
window for each dimension of the input tensor.
strides: An int or list of `ints` that has length `1` or `3`. The stride of
the sliding window for each dimension of the input tensor.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
the "returns" section of `tf.nn.convolution` for details.
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
padding algorithm to use, or a list indicating the explicit paddings at
the start and end of each dimension. When explicit padding is used and
data_format is `"NWC"`, this should be in the form `[[0, 0], [pad_left,
pad_right], [0, 0]]`. When explicit padding used and data_format is
`"NCW"`, this should be in the form `[[0, 0], [0, 0], [pad_left,
pad_right]]`. When using explicit padding, the size of the paddings cannot
be greater than the sliding window size.
data_format: An optional string from: "NWC", "NCW". Defaults to "NWC".
name: A name for the operation (optional).
@ -4601,11 +4636,17 @@ def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
The max pooled output tensor.
"""
with ops.name_scope(name, "MaxPool1d", [input]) as name:
if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
"explicit padding")
if data_format is None:
data_format = "NWC"
channel_index = 1 if data_format.startswith("NC") else 2
ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
strides = [1] + _get_sequence(strides, 1, channel_index, "strides")
padding, explicit_paddings = convert_padding(padding, 3)
if padding == "EXPLICIT":
explicit_paddings = [0, 0] + explicit_paddings
expanding_dim = 1 if data_format == "NWC" else 2
data_format = "NHWC" if data_format == "NWC" else "NCHW"
@ -4616,6 +4657,7 @@ def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
ksize=ksize,
strides=strides,
padding=padding,
explicit_paddings=explicit_paddings,
data_format=data_format,
name=name)
return array_ops.squeeze(result, expanding_dim)
@ -4634,8 +4676,15 @@ def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
the window for each dimension of the input tensor.
strides: An int or list of `ints` that has length `1`, `2` or `4`. The
stride of the sliding window for each dimension of the input tensor.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
the "returns" section of `tf.nn.convolution` for details.
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
padding algorithm to use, or a list indicating the explicit paddings at
the start and end of each dimension. When explicit padding is used and
data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
[pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
padding, the size of the paddings cannot be greater than the sliding
window size.
data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
name: Optional name for the operation.
@ -4650,12 +4699,17 @@ def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
ksize = _get_sequence(ksize, 2, channel_index, "ksize")
strides = _get_sequence(strides, 2, channel_index, "strides")
if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
"explicit padding")
padding, explicit_paddings = convert_padding(padding)
return gen_nn_ops.max_pool(
input,
ksize=ksize,
strides=strides,
padding=padding,
explicit_paddings=explicit_paddings,
data_format=data_format,
name=name)
# pylint: enable=redefined-builtin

View File

@ -2426,7 +2426,7 @@ tf_module {
}
member_method {
name: "MaxPool"
argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'None\'], "
}
member_method {
name: "MaxPool3D"
@ -2442,7 +2442,7 @@ tf_module {
}
member_method {
name: "MaxPoolGrad"
argspec: "args=[\'orig_input\', \'orig_output\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
argspec: "args=[\'orig_input\', \'orig_output\', \'grad\', \'ksize\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'None\'], "
}
member_method {
name: "MaxPoolGradGrad"

View File

@ -2426,7 +2426,7 @@ tf_module {
}
member_method {
name: "MaxPool"
argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'None\'], "
}
member_method {
name: "MaxPool3D"
@ -2442,7 +2442,7 @@ tf_module {
}
member_method {
name: "MaxPoolGrad"
argspec: "args=[\'orig_input\', \'orig_output\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
argspec: "args=[\'orig_input\', \'orig_output\', \'grad\', \'ksize\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'None\'], "
}
member_method {
name: "MaxPoolGradGrad"