Add explicit padding to max_pool.
Explicit padding is supported with CPUs, GPUs, and XLA. No XLA code needed to be changed as the existing codepaths already worked with explicit padding. Minor modifications were required for CPU and GPU code as the algorithms already had to take into account padding to support SAME padding. PiperOrigin-RevId: 328847222 Change-Id: Iff18840e3d9bf3643a676403adddca61d6ecd6ca
This commit is contained in:
parent
ab783f83f6
commit
93c22a269b
RELEASE.md
tensorflow
compiler/mlir
lite/transforms
tensorflow/ir
xla
core
framework
kernels
avgpooling_op.ccconv_2d.hconv_2d_gpu.hconv_grad_filter_ops.ccconv_grad_input_ops.ccconv_grad_ops_3d.ccconv_ops.ccconv_ops_3d.ccconv_ops_fused_impl.hmaxpooling_op.ccpooling_ops_common.ccpooling_ops_common.hpooling_ops_common_gpu.hquantized_pooling_ops.cc
ops
python
tools/api/golden
@ -183,13 +183,11 @@
|
||||
checkpoint saved in the `variables/` folder in the SavedModel.
|
||||
* When restoring, `save_path` can be a path to a SavedModel. The function
|
||||
will automatically find the checkpoint in the SavedModel.
|
||||
* `tf.nn`:
|
||||
* `tf.nn.max_pool2d` now supports explicit padding.
|
||||
* Other:
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
and "denylist" where possible. Please see
|
||||
https://developers.google.com/style/word-list#blacklist for more context.
|
||||
<ADD RELEASE NOTES HERE>
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
|
@ -149,7 +149,6 @@ def LegalizeMaxPool2D : Pat<
|
||||
IsIntList1XY1:$ksize,
|
||||
IsIntList1XY1:$strides,
|
||||
$padding,
|
||||
$explicit_paddings,
|
||||
IsDataFormatNHWC:$format),
|
||||
(TFL_MaxPool2DOp $value,
|
||||
/*padding=*/$padding,
|
||||
|
@ -6329,8 +6329,7 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter
|
||||
|
||||
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$ksize,
|
||||
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$strides,
|
||||
TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
|
||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
|
||||
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
|
||||
DefaultValuedAttr<TF_AnyStrAttrOf<["NHWC", "NCHW", "NCHW_VECT_C"]>, "NHWC">:$data_format
|
||||
);
|
||||
|
||||
@ -6399,8 +6398,7 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> {
|
||||
|
||||
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$ksize,
|
||||
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$strides,
|
||||
TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
|
||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
|
||||
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
|
||||
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format
|
||||
);
|
||||
|
||||
|
@ -1269,15 +1269,6 @@ func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x
|
||||
return %0 : tensor<2x8x4x7x7xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: maxpool_explicit_padding
|
||||
func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
|
||||
// CHECK: tf.MaxPool
|
||||
// TODO(b/165938852): need to support explicit padding in max_pool.
|
||||
|
||||
%0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
|
||||
return %0 : tensor<2x3x5x7xi32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxPoolGrad op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2474,12 +2474,6 @@ class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
|
||||
Type element_type =
|
||||
op.input().getType().template cast<TensorType>().getElementType();
|
||||
if (!element_type.isSignlessIntOrFloat()) return failure();
|
||||
tensorflow::Padding padding;
|
||||
if (!GetPaddingFromString(op.padding().str(), &padding).ok())
|
||||
return failure();
|
||||
if (padding == tensorflow::Padding::EXPLICIT) {
|
||||
return failure();
|
||||
}
|
||||
Location loc = op.getLoc();
|
||||
ConstOp init = GetScalarLimitConstOfType(element_type, loc,
|
||||
hlo::kInfinityLowest, &rewriter);
|
||||
|
@ -1477,8 +1477,7 @@ Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
|
||||
bool supports_explicit_padding) {
|
||||
Status MaxPoolShape(shape_inference::InferenceContext* c) {
|
||||
string data_format_str;
|
||||
TensorFormat data_format;
|
||||
Status s = c->GetAttr("data_format", &data_format_str);
|
||||
@ -1531,39 +1530,14 @@ Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
|
||||
Padding padding;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
|
||||
|
||||
std::vector<int64> explicit_paddings;
|
||||
if (supports_explicit_padding) {
|
||||
Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
|
||||
// Use the default value, which is an empty list, if the attribute is not
|
||||
// found. Otherwise return the error to the caller.
|
||||
if (!status.ok() && !errors::IsNotFound(status)) {
|
||||
return status;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
|
||||
/*num_dims=*/4, data_format));
|
||||
} else {
|
||||
DCHECK(padding != Padding::EXPLICIT);
|
||||
}
|
||||
|
||||
ShapeHandle output_shape;
|
||||
DimensionHandle output_rows, output_cols, output_depth;
|
||||
int64 pad_rows_before = -1, pad_rows_after = -1;
|
||||
int64 pad_cols_before = -1, pad_cols_after = -1;
|
||||
if (padding == Padding::EXPLICIT) {
|
||||
GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
|
||||
&pad_rows_before, &pad_rows_after);
|
||||
GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
|
||||
&pad_cols_before, &pad_cols_after);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
|
||||
c, in_rows_dim, kernel_rows, /*dilation_rate=*/1, stride_rows, padding,
|
||||
pad_rows_before, pad_rows_after, &output_rows));
|
||||
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
|
||||
c, in_cols_dim, kernel_cols, /*dilation_rate=*/1, stride_cols, padding,
|
||||
pad_cols_before, pad_cols_after, &output_cols));
|
||||
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
|
||||
c, in_depth_dim, kernel_depth, /*dilation_rate=*/1, stride_depth, padding,
|
||||
/*pad_before*/ 0, /*pad_after*/ 0, &output_depth));
|
||||
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
|
||||
c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
|
||||
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
|
||||
c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
|
||||
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
|
||||
c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
|
||||
|
||||
TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
|
||||
{output_rows, output_cols},
|
||||
@ -1573,14 +1547,6 @@ Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MaxPoolShape(shape_inference::InferenceContext* c) {
|
||||
return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false);
|
||||
}
|
||||
|
||||
Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
|
||||
return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true);
|
||||
}
|
||||
|
||||
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
|
||||
string data_format_str;
|
||||
TensorFormat data_format;
|
||||
|
@ -168,11 +168,7 @@ Status MatrixDiagV2Shape(shape_inference::InferenceContext* c);
|
||||
// Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations.
|
||||
Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c);
|
||||
|
||||
// Shape function for MaxPool-like operations that support explicit padding.
|
||||
Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c);
|
||||
|
||||
// Shape function for MaxPool-like operations that do not support explicit
|
||||
// padding.
|
||||
// Shape function for MaxPool-like operations.
|
||||
Status MaxPoolShape(shape_inference::InferenceContext* c);
|
||||
|
||||
// Shape function for MaxPoolV2-like operations.
|
||||
|
@ -81,13 +81,8 @@ class AvgPoolingOp : public UnaryOp<T> {
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& tensor_in = context->input(0);
|
||||
PoolParameters params{context,
|
||||
ksize_,
|
||||
stride_,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
data_format_,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize_, stride_,
|
||||
padding_, data_format_, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -151,13 +146,8 @@ class AvgPoolingOp<GPUDevice, T> : public UnaryOp<T> {
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& tensor_in = context->input(0);
|
||||
PoolParameters params{context,
|
||||
ksize_,
|
||||
stride_,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
data_format_,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize_, stride_,
|
||||
padding_, data_format_, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -179,14 +169,14 @@ class AvgPoolingOp<GPUDevice, T> : public UnaryOp<T> {
|
||||
|
||||
#if CUDNN_VERSION >= 7300
|
||||
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, ksize_,
|
||||
stride_, padding_, /*explicit_paddings=*/{},
|
||||
data_format_, tensor_in, output_shape,
|
||||
stride_, padding_, data_format_, tensor_in,
|
||||
output_shape,
|
||||
/*propagate_nans=*/false);
|
||||
#else
|
||||
if (data_format_ == FORMAT_NCHW) {
|
||||
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, ksize_,
|
||||
stride_, padding_, /*explicit_paddings=*/{},
|
||||
data_format_, tensor_in, output_shape,
|
||||
stride_, padding_, data_format_, tensor_in,
|
||||
output_shape,
|
||||
/*propagate_nans=*/false);
|
||||
} else {
|
||||
Tensor* output = nullptr;
|
||||
@ -456,10 +446,10 @@ class AvgPoolingGradOp<GPUDevice, T> : public OpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
DnnPoolingGradOp<T>::Compute(
|
||||
context, se::dnn::PoolingMode::kAverage, ksize_, stride_, padding_,
|
||||
/*explicit_paddings=*/{}, data_format_, nullptr, nullptr, out_backprop,
|
||||
output_shape, /*propagate_nans=*/false);
|
||||
DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
|
||||
ksize_, stride_, padding_, data_format_,
|
||||
nullptr, nullptr, out_backprop, output_shape,
|
||||
/*propagate_nans=*/false);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -543,8 +533,7 @@ class AvgPoolingGradOpCustomGPUKernel : public OpKernel {
|
||||
|
||||
#if CUDNN_VERSION >= 7300
|
||||
DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
|
||||
ksize_, stride_, padding_,
|
||||
/*explicit_paddings=*/{}, data_format_,
|
||||
ksize_, stride_, padding_, data_format_,
|
||||
nullptr, nullptr, out_backprop, output_shape,
|
||||
/*propagate_nans=*/false);
|
||||
#else
|
||||
|
@ -316,7 +316,7 @@ struct PadInput {
|
||||
const std::array<int, NDIMS - 2>& padding_left,
|
||||
const std::array<int, NDIMS - 2>& padding_right,
|
||||
typename TTypes<T, NDIMS, IndexType>::Tensor out,
|
||||
TensorFormat format, T padding_value = T{}) {
|
||||
TensorFormat format) {
|
||||
Eigen::array<Eigen::IndexPair<IndexType>, NDIMS> padding;
|
||||
padding[GetTensorDimIndex<NDIMS - 2>(format, 'N')] = {0, 0};
|
||||
for (int i = 0; i < NDIMS - 2; ++i) {
|
||||
@ -324,7 +324,7 @@ struct PadInput {
|
||||
padding_left[i], padding_right[i]};
|
||||
}
|
||||
padding[GetTensorDimIndex<NDIMS - 2>(format, 'C')] = {0, 0};
|
||||
out.device(d) = in.pad(padding, padding_value);
|
||||
out.device(d) = in.pad(padding);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -417,12 +417,14 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
|
||||
}
|
||||
|
||||
// A Gpu custom kernel that convert input to output, given proper padding on
|
||||
// the left and the top.
|
||||
// the left and the top. The padded value is zero.
|
||||
template <typename T, int NDIMS>
|
||||
__global__ void PadInputCustomKernelNHWC(
|
||||
int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
|
||||
T* __restrict__ output, Dimension<NDIMS> output_dims,
|
||||
Dimension<NDIMS - 2> padding_left, T padding_value) {
|
||||
__global__ void PadInputCustomKernelNHWC(int nthreads,
|
||||
const T* __restrict__ input,
|
||||
Dimension<NDIMS> input_dims,
|
||||
T* __restrict__ output,
|
||||
Dimension<NDIMS> output_dims,
|
||||
Dimension<NDIMS - 2> padding_left) {
|
||||
GPU_1D_KERNEL_LOOP(index, nthreads) {
|
||||
int output_index = index;
|
||||
Index<NDIMS> output_tensor_index =
|
||||
@ -442,16 +444,18 @@ __global__ void PadInputCustomKernelNHWC(
|
||||
const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
|
||||
output[output_index] = input[input_index];
|
||||
} else {
|
||||
output[output_index] = padding_value;
|
||||
output[output_index] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int NDIMS>
|
||||
__global__ void PadInputCustomKernelNCHW(
|
||||
int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
|
||||
T* __restrict__ output, Dimension<NDIMS> output_dims,
|
||||
Dimension<NDIMS - 2> padding_left, T padding_value) {
|
||||
__global__ void PadInputCustomKernelNCHW(int nthreads,
|
||||
const T* __restrict__ input,
|
||||
Dimension<NDIMS> input_dims,
|
||||
T* __restrict__ output,
|
||||
Dimension<NDIMS> output_dims,
|
||||
Dimension<NDIMS - 2> padding_left) {
|
||||
GPU_1D_KERNEL_LOOP(index, nthreads) {
|
||||
int output_index = index;
|
||||
Index<NDIMS> output_tensor_index =
|
||||
@ -471,7 +475,7 @@ __global__ void PadInputCustomKernelNCHW(
|
||||
const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
|
||||
output[output_index] = input[input_index];
|
||||
} else {
|
||||
output[output_index] = padding_value;
|
||||
output[output_index] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -568,7 +572,7 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
|
||||
const std::array<int, NDIMS - 2>& padding_left,
|
||||
const std::array<int, NDIMS - 2>& padding_right,
|
||||
typename TTypes<T, NDIMS, int>::Tensor out,
|
||||
TensorFormat format, T padding_value = T{}) {
|
||||
TensorFormat format) {
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
|
||||
Dimension<NDIMS> input_dims;
|
||||
for (int i = 0; i < NDIMS; ++i) {
|
||||
@ -585,14 +589,12 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
PadInputCustomKernelNHWC<T, NDIMS>, config.block_count,
|
||||
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
||||
in.data(), input_dims, out.data(), output_dims, padding_left_dim,
|
||||
padding_value));
|
||||
in.data(), input_dims, out.data(), output_dims, padding_left_dim));
|
||||
} else if (format == FORMAT_NCHW) {
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
PadInputCustomKernelNCHW<T, NDIMS>, config.block_count,
|
||||
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
||||
in.data(), input_dims, out.data(), output_dims, padding_left_dim,
|
||||
padding_value));
|
||||
in.data(), input_dims, out.data(), output_dims, padding_left_dim));
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid data format: " << format;
|
||||
}
|
||||
|
@ -1135,20 +1135,19 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
|
||||
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
|
||||
typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
|
||||
template <> \
|
||||
void PadInput<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const std::array<int, 2>& padding_left, \
|
||||
const std::array<int, 2>& padding_right, \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
|
||||
T padding_value); \
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
|
||||
typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
|
||||
template <> \
|
||||
void PadInput<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const std::array<int, 2>& padding_left, \
|
||||
const std::array<int, 2>& padding_right, \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
|
||||
extern template struct PadInput<GPUDevice, T, int, 4>;
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
|
@ -1333,20 +1333,19 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
|
||||
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
|
||||
typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
|
||||
template <> \
|
||||
void PadInput<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const std::array<int, 2>& padding_left, \
|
||||
const std::array<int, 2>& padding_right, \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
|
||||
T padding_value); \
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
|
||||
typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
|
||||
template <> \
|
||||
void PadInput<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const std::array<int, 2>& padding_left, \
|
||||
const std::array<int, 2>& padding_right, \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
|
||||
extern template struct PadInput<GPUDevice, T, int, 4>;
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
|
@ -1082,8 +1082,7 @@ namespace functor {
|
||||
const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
|
||||
const std::array<int, 3>& padding_left, \
|
||||
const std::array<int, 3>& padding_right, \
|
||||
typename TTypes<T, 5, int>::Tensor out, TensorFormat format, \
|
||||
T padding_value);
|
||||
typename TTypes<T, 5, int>::Tensor out, TensorFormat format);
|
||||
|
||||
DECLARE_GPU_SPEC(Eigen::half);
|
||||
DECLARE_GPU_SPEC(float);
|
||||
|
@ -1183,8 +1183,7 @@ namespace functor {
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const std::array<int, 2>& padding_left, \
|
||||
const std::array<int, 2>& padding_right, \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
|
||||
T padding_value); \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
|
||||
extern template struct PadInput<GPUDevice, T, int, 4>
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
|
@ -539,8 +539,7 @@ namespace functor {
|
||||
const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
|
||||
const std::array<int, 3>& padding_left, \
|
||||
const std::array<int, 3>& padding_right, \
|
||||
typename TTypes<T, 5, int>::Tensor out, TensorFormat format, \
|
||||
T padding_value); \
|
||||
typename TTypes<T, 5, int>::Tensor out, TensorFormat format); \
|
||||
template <> \
|
||||
void NHWCToNCHW<GPUDevice, T, 5>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in, \
|
||||
|
@ -765,20 +765,19 @@ class FusedConv2DOp : public OpKernel {
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define DECLARE_FUNCTOR_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
|
||||
typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
|
||||
template <> \
|
||||
void PadInput<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const std::array<int, 2>& padding_left, \
|
||||
const std::array<int, 2>& padding_right, \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
|
||||
T padding_value); \
|
||||
#define DECLARE_FUNCTOR_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, FilterTensorFormat dst_filter_format, \
|
||||
typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
|
||||
template <> \
|
||||
void PadInput<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const std::array<int, 2>& padding_left, \
|
||||
const std::array<int, 2>& padding_right, \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
|
||||
extern template struct PadInput<GPUDevice, T, int, 4>
|
||||
|
||||
// Registration of the GPU implementations.
|
||||
|
@ -105,8 +105,8 @@ static void SpatialMaxPoolWithArgMaxHelper(
|
||||
const int32 depth = params.depth;
|
||||
const int32 in_rows = params.tensor_in_rows;
|
||||
const int32 in_cols = params.tensor_in_cols;
|
||||
const int32 pad_top = params.pad_top;
|
||||
const int32 pad_left = params.pad_left;
|
||||
const int32 pad_rows = params.pad_rows;
|
||||
const int32 pad_cols = params.pad_cols;
|
||||
const int32 window_rows = params.window_rows;
|
||||
const int32 window_cols = params.window_cols;
|
||||
const int32 row_stride = params.row_stride;
|
||||
@ -131,8 +131,8 @@ static void SpatialMaxPoolWithArgMaxHelper(
|
||||
for (int w = 0; w < in_cols; ++w) {
|
||||
// (h_start, h_end) * (w_start, w_end) is the range that the input
|
||||
// vector projects to.
|
||||
const int hpad = h + pad_top;
|
||||
const int wpad = w + pad_left;
|
||||
const int hpad = h + pad_rows;
|
||||
const int wpad = w + pad_cols;
|
||||
const int h_start =
|
||||
(hpad < window_rows) ? 0 : (hpad - window_rows) / row_stride + 1;
|
||||
const int h_end = std::min(hpad / row_stride + 1, out_height);
|
||||
@ -243,13 +243,6 @@ class MaxPoolingGradOp : public OpKernel {
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||
|
||||
if (padding_ == Padding::EXPLICIT) {
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr("explicit_paddings", &explicit_paddings_));
|
||||
OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
|
||||
/*num_dims=*/4, data_format_));
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
@ -304,13 +297,8 @@ class MaxPoolingGradOp : public OpKernel {
|
||||
errors::Unimplemented(
|
||||
"MaxPoolingGrad is not yet supported on the depth dimension."));
|
||||
|
||||
PoolParameters params{context,
|
||||
ksize,
|
||||
stride,
|
||||
padding_,
|
||||
explicit_paddings_,
|
||||
FORMAT_NHWC,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize, stride,
|
||||
padding_, FORMAT_NHWC, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -328,7 +316,6 @@ class MaxPoolingGradOp : public OpKernel {
|
||||
std::vector<int32> ksize_;
|
||||
std::vector<int32> stride_;
|
||||
Padding padding_;
|
||||
std::vector<int64> explicit_paddings_;
|
||||
TensorFormat data_format_;
|
||||
};
|
||||
|
||||
@ -360,12 +347,7 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
}
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||
if (padding_ == Padding::EXPLICIT) {
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr("explicit_paddings", &explicit_paddings_));
|
||||
OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
|
||||
/*num_dims=*/4, data_format_));
|
||||
}
|
||||
|
||||
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
|
||||
&propagate_nans_));
|
||||
}
|
||||
@ -410,26 +392,16 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
|
||||
OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
|
||||
errors::Unimplemented(
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
int64 pad_top, pad_bottom, pad_left, pad_right;
|
||||
if (padding_ == Padding::EXPLICIT) {
|
||||
GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H',
|
||||
/*pad_top=*/&pad_top,
|
||||
/*pad_bottom=*/&pad_bottom);
|
||||
GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W',
|
||||
/*pad_left=*/&pad_left,
|
||||
/*pad_right=*/&pad_right);
|
||||
}
|
||||
DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize,
|
||||
stride, padding_, explicit_paddings_,
|
||||
data_format_, &tensor_in, &tensor_out,
|
||||
out_backprop, output_shape, propagate_nans_);
|
||||
stride, padding_, data_format_, &tensor_in,
|
||||
&tensor_out, out_backprop, output_shape,
|
||||
propagate_nans_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int32> ksize_;
|
||||
std::vector<int32> stride_;
|
||||
Padding padding_;
|
||||
std::vector<int64> explicit_paddings_;
|
||||
TensorFormat data_format_;
|
||||
bool propagate_nans_;
|
||||
};
|
||||
@ -520,13 +492,8 @@ class MaxPoolingGradGradOp : public OpKernel {
|
||||
errors::Unimplemented(
|
||||
"MaxPoolingGrad is not yet supported on the depth dimension."));
|
||||
|
||||
PoolParameters params{context,
|
||||
ksize,
|
||||
stride,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
FORMAT_NHWC,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize, stride,
|
||||
padding_, FORMAT_NHWC, tensor_in.shape()};
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
||||
{2}, 0, tensor_out.shape(), &output));
|
||||
@ -584,8 +551,8 @@ class MaxPoolingGradGradOp : public OpKernel {
|
||||
const int32 depth = params.depth;
|
||||
const int32 in_rows = params.tensor_in_rows;
|
||||
const int32 in_cols = params.tensor_in_cols;
|
||||
const int32 pad_top = params.pad_top;
|
||||
const int32 pad_left = params.pad_left;
|
||||
const int32 pad_rows = params.pad_rows;
|
||||
const int32 pad_cols = params.pad_cols;
|
||||
const int32 window_rows = params.window_rows;
|
||||
const int32 window_cols = params.window_cols;
|
||||
const int32 row_stride = params.row_stride;
|
||||
@ -607,9 +574,9 @@ class MaxPoolingGradGradOp : public OpKernel {
|
||||
for (int pw = 0; pw < out_width; ++pw) {
|
||||
// (h_start, h_end) * (w_start, w_end) is the range that the input
|
||||
// vector projects to.
|
||||
int h_start = ph * row_stride - pad_top;
|
||||
int h_start = ph * row_stride - pad_rows;
|
||||
const int h_end = std::min(h_start + window_rows, in_rows);
|
||||
int w_start = pw * col_stride - pad_left;
|
||||
int w_start = pw * col_stride - pad_cols;
|
||||
const int w_end = std::min(w_start + window_cols, in_cols);
|
||||
h_start = std::max(h_start, 0);
|
||||
w_start = std::max(w_start, 0);
|
||||
@ -724,20 +691,15 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
|
||||
errors::Unimplemented(
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
|
||||
PoolParameters params{context,
|
||||
ksize,
|
||||
stride,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
data_format_,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize, stride,
|
||||
padding_, data_format_, tensor_in.shape()};
|
||||
|
||||
functor::MaxPoolGradBackwardNoMask<T>()(
|
||||
data_format_, tensor_in.flat<T>().data(), tensor_out.flat<T>().data(),
|
||||
params.tensor_in_batch, params.out_height, params.out_width,
|
||||
params.depth, params.tensor_in_rows, params.tensor_in_cols,
|
||||
params.window_rows, params.window_cols, params.row_stride,
|
||||
params.col_stride, params.pad_top, params.pad_left,
|
||||
params.col_stride, params.pad_rows, params.pad_cols,
|
||||
out_grad_backprop.flat<T>().data(), output->flat<T>().data(),
|
||||
context->eigen_device<Eigen::GpuDevice>());
|
||||
}
|
||||
@ -781,22 +743,13 @@ class MaxPoolingNoMaskOp : public OpKernel {
|
||||
OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
|
||||
errors::Unimplemented(
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
OP_REQUIRES(
|
||||
context, padding_ != EXPLICIT,
|
||||
errors::Unimplemented(
|
||||
"Explicit padding is not supported for MaxPoolingNoMaskOp."));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& tensor_in = context->input(0);
|
||||
|
||||
PoolParameters params{context,
|
||||
ksize_,
|
||||
stride_,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
data_format_,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize_, stride_,
|
||||
padding_, data_format_, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -873,13 +826,8 @@ class MaxPoolingNoMaskV2Op : public OpKernel {
|
||||
OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
|
||||
errors::Unimplemented(
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
PoolParameters params{context,
|
||||
ksize,
|
||||
stride,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
data_format_,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize, stride,
|
||||
padding_, data_format_, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -941,13 +889,8 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& tensor_in = context->input(0);
|
||||
|
||||
PoolParameters params{context,
|
||||
ksize_,
|
||||
stride_,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
FORMAT_NHWC,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize_, stride_,
|
||||
padding_, FORMAT_NHWC, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -1060,13 +1003,8 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
|
||||
const Tensor& grad_in = context->input(1);
|
||||
const Tensor& argmax = context->input(2);
|
||||
|
||||
PoolParameters params{context,
|
||||
ksize_,
|
||||
stride_,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
FORMAT_NHWC,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize_, stride_,
|
||||
padding_, FORMAT_NHWC, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -1118,13 +1056,8 @@ class MaxPoolingGradGradWithArgmaxOp : public OpKernel {
|
||||
const Tensor& grad_in = context->input(1);
|
||||
const Tensor& argmax = context->input(2);
|
||||
|
||||
PoolParameters params{context,
|
||||
ksize_,
|
||||
stride_,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
FORMAT_NHWC,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize_, stride_,
|
||||
padding_, FORMAT_NHWC, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -1167,8 +1100,6 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
|
||||
errors::InvalidArgument("Sliding window stride field must "
|
||||
"specify 4 dimensions"));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("explicit_paddings", &explicit_paddings_));
|
||||
const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
|
||||
const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
|
||||
OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
|
||||
@ -1182,9 +1113,8 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& tensor_in = context->input(0);
|
||||
|
||||
PoolParameters params{
|
||||
context, ksize_, stride_, padding_, explicit_paddings_,
|
||||
data_format_, tensor_in.shape()};
|
||||
PoolParameters params{context, ksize_, stride_,
|
||||
padding_, data_format_, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -1201,20 +1131,15 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
|
||||
|
||||
#if CUDNN_VERSION >= 7300
|
||||
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
|
||||
stride_, padding_, explicit_paddings_,
|
||||
data_format_, tensor_in, out_shape,
|
||||
propagate_nans_);
|
||||
stride_, padding_, data_format_, tensor_in,
|
||||
out_shape, propagate_nans_);
|
||||
#else
|
||||
// These is_int8x4 checks avoid linker errors for missing qint8 kernels.
|
||||
if (!is_int8x4 && data_format_ == FORMAT_NCHW) {
|
||||
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
|
||||
stride_, padding_, explicit_paddings_,
|
||||
data_format_, tensor_in, out_shape,
|
||||
propagate_nans_);
|
||||
stride_, padding_, data_format_, tensor_in,
|
||||
out_shape, propagate_nans_);
|
||||
} else {
|
||||
OP_REQUIRES(context, padding_ != EXPLICIT,
|
||||
errors::Unimplemented("Explicit padding is not supported ",
|
||||
"when CUDNN is not enabled."));
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||
if (is_int8x4) {
|
||||
@ -1240,7 +1165,6 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
|
||||
std::vector<int32> ksize_;
|
||||
std::vector<int32> stride_;
|
||||
Padding padding_;
|
||||
std::vector<int64> explicit_paddings_;
|
||||
TensorFormat data_format_;
|
||||
bool propagate_nans_;
|
||||
};
|
||||
@ -1304,13 +1228,8 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
|
||||
errors::Unimplemented(
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
|
||||
PoolParameters params{context,
|
||||
ksize,
|
||||
stride,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
data_format_,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize, stride,
|
||||
padding_, data_format_, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -1320,9 +1239,8 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
|
||||
params.out_width, params.depth);
|
||||
if (data_format_ == FORMAT_NCHW) {
|
||||
DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize,
|
||||
stride, padding_, explicit_paddings_,
|
||||
data_format_, tensor_in, out_shape,
|
||||
propagate_nans_);
|
||||
stride, padding_, data_format_, tensor_in,
|
||||
out_shape, propagate_nans_);
|
||||
} else {
|
||||
CHECK(data_format_ == FORMAT_NHWC)
|
||||
<< "MaxPool only supports NCHW or NHWC format";
|
||||
@ -1337,7 +1255,6 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
|
||||
std::vector<int32> ksize_;
|
||||
std::vector<int32> stride_;
|
||||
Padding padding_;
|
||||
std::vector<int64> explicit_paddings_;
|
||||
TensorFormat data_format_;
|
||||
bool propagate_nans_;
|
||||
};
|
||||
@ -1350,7 +1267,7 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
|
||||
input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
|
||||
params.tensor_in_cols, params.depth, params.out_height,
|
||||
params.out_width, params.window_rows, params.window_cols,
|
||||
params.row_stride, params.col_stride, params.pad_top, params.pad_left,
|
||||
params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
|
||||
output->flat<T>().data(), nullptr, context->eigen_gpu_device(),
|
||||
propagate_nans, false);
|
||||
if (!status) {
|
||||
@ -1369,7 +1286,7 @@ struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
|
||||
input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
|
||||
params.tensor_in_cols, params.depth, params.out_height,
|
||||
params.out_width, params.window_rows, params.window_cols,
|
||||
params.row_stride, params.col_stride, params.pad_top, params.pad_left,
|
||||
params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
|
||||
output->flat<T>().data(),
|
||||
reinterpret_cast<int64*>(argmax->flat<int64>().data()),
|
||||
context->eigen_gpu_device(), propagate_nans, include_batch_in_index);
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/kernel_shape_util.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -52,41 +51,10 @@ struct RawType<qint8> {
|
||||
|
||||
} // namespace
|
||||
|
||||
Status CheckPaddingSize(int64 window_rows, int64 window_cols, int64 pad_top,
|
||||
int64 pad_bottom, int64 pad_left, int64 pad_right) {
|
||||
if (!FastBoundsCheck(pad_top, window_rows)) {
|
||||
return errors::InvalidArgument("Top padding ", pad_top,
|
||||
" needs to be smaller than the "
|
||||
"window size ",
|
||||
window_rows);
|
||||
}
|
||||
if (!FastBoundsCheck(pad_bottom, window_rows)) {
|
||||
return errors::InvalidArgument("Bottom padding ", pad_bottom,
|
||||
" needs to be smaller than the "
|
||||
"window size ",
|
||||
window_rows);
|
||||
}
|
||||
if (!FastBoundsCheck(pad_left, window_cols)) {
|
||||
return errors::InvalidArgument("Left padding ", pad_left,
|
||||
" needs to be smaller than the "
|
||||
"window size ",
|
||||
window_cols);
|
||||
}
|
||||
if (!FastBoundsCheck(pad_right, window_cols)) {
|
||||
return errors::InvalidArgument("Right padding ", pad_right,
|
||||
" needs to be smaller than the "
|
||||
"window size ",
|
||||
window_cols);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PoolParameters::PoolParameters(OpKernelContext* context,
|
||||
const std::vector<int32>& ksize,
|
||||
const std::vector<int32>& stride,
|
||||
Padding padding,
|
||||
std::vector<int64> explicit_paddings,
|
||||
TensorFormat data_format,
|
||||
Padding padding, TensorFormat data_format,
|
||||
const TensorShape& tensor_in_shape) {
|
||||
// For maxpooling, tensor_in should have 2 spatial dimensions.
|
||||
// Note: the total number of dimensions could be 4 for NHWC, NCHW,
|
||||
@ -117,24 +85,14 @@ PoolParameters::PoolParameters(OpKernelContext* context,
|
||||
errors::Unimplemented(
|
||||
"MaxPooling supports exactly one of pooling across depth "
|
||||
"or pooling across width/height."));
|
||||
if (padding == Padding::EXPLICIT) {
|
||||
OP_REQUIRES_OK(context, CheckValidPadding(padding, explicit_paddings,
|
||||
/*num_dims=*/4, data_format));
|
||||
GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &pad_top,
|
||||
&pad_bottom);
|
||||
GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &pad_left,
|
||||
&pad_right);
|
||||
OP_REQUIRES_OK(context, CheckPaddingSize(window_rows, window_cols, pad_top,
|
||||
pad_bottom, pad_left, pad_right));
|
||||
}
|
||||
|
||||
if (depth_window == 1) {
|
||||
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
|
||||
tensor_in_rows, window_rows, row_stride,
|
||||
padding, &out_height, &pad_top, &pad_bottom));
|
||||
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
|
||||
tensor_in_cols, window_cols, col_stride,
|
||||
padding, &out_width, &pad_left, &pad_right));
|
||||
OP_REQUIRES_OK(
|
||||
context, GetWindowedOutputSize(tensor_in_rows, window_rows, row_stride,
|
||||
padding, &out_height, &pad_rows));
|
||||
OP_REQUIRES_OK(
|
||||
context, GetWindowedOutputSize(tensor_in_cols, window_cols, col_stride,
|
||||
padding, &out_width, &pad_cols));
|
||||
pad_depth = 0;
|
||||
out_depth = depth;
|
||||
} else {
|
||||
@ -182,7 +140,6 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
|
||||
se::dnn::PoolingMode pooling_mode,
|
||||
const std::vector<int32>& size,
|
||||
const std::vector<int32>& stride, Padding padding,
|
||||
std::vector<int64> explicit_paddings,
|
||||
TensorFormat data_format, const Tensor& tensor_in,
|
||||
const TensorShape& tensor_out_shape,
|
||||
bool propagate_nans) {
|
||||
@ -193,18 +150,14 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
|
||||
return;
|
||||
}
|
||||
|
||||
PoolParameters params{
|
||||
context, size, stride, padding,
|
||||
explicit_paddings, data_format, tensor_in.shape()};
|
||||
PoolParameters params{context, size, stride,
|
||||
padding, data_format, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
|
||||
int batch_size = params.tensor_in_batch;
|
||||
int depth = params.depth;
|
||||
int tensor_in_cols = params.tensor_in_cols;
|
||||
int tensor_in_rows = params.tensor_in_rows;
|
||||
|
||||
#if CUDNN_VERSION < 7300
|
||||
/// Earlier versions do not support NHWC format, so we need to convert it
|
||||
/// to NCHW before calling cudnn. We need to get rid of this once it is done
|
||||
@ -233,7 +186,7 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
|
||||
}
|
||||
se::dnn::DataLayout data_layout = se::dnn::DataLayout::kBatchDepthYX;
|
||||
#else
|
||||
Tensor transformed_input = tensor_in;
|
||||
auto& transformed_input = tensor_in;
|
||||
auto& transformed_output = *tensor_out;
|
||||
se::dnn::DataLayout data_layout;
|
||||
switch (data_format) {
|
||||
@ -256,87 +209,21 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
|
||||
ToString(data_format)));
|
||||
}
|
||||
#endif
|
||||
|
||||
int64 vertical_padding = params.pad_top;
|
||||
int64 horizontal_padding = params.pad_left;
|
||||
|
||||
if (padding == EXPLICIT && (params.pad_top != params.pad_bottom ||
|
||||
params.pad_left != params.pad_right)) {
|
||||
// cuDNN only supports padding the same amount on the left and right sides,
|
||||
// and on the top and bottom sides. So we manually create a new padded
|
||||
// input tensor such that we can pass it to cuDNN.
|
||||
const int64 common_padding_rows =
|
||||
std::min(params.pad_top, params.pad_bottom);
|
||||
const int64 common_padding_cols =
|
||||
std::min(params.pad_left, params.pad_right);
|
||||
|
||||
Tensor padded_input;
|
||||
const int64 padding_rows_diff =
|
||||
std::abs(params.pad_top - params.pad_bottom);
|
||||
const int64 padding_cols_diff =
|
||||
std::abs(params.pad_left - params.pad_right);
|
||||
|
||||
const int64 new_in_rows = tensor_in_rows + padding_rows_diff;
|
||||
const int64 new_in_cols = tensor_in_cols + padding_cols_diff;
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
context->allocate_temp(DataTypeToEnum<T>::value,
|
||||
ShapeFromFormat(data_format, batch_size,
|
||||
new_in_rows, new_in_cols, depth),
|
||||
&padded_input));
|
||||
const int64 input_pad_top = params.pad_top - common_padding_rows;
|
||||
const int64 input_pad_bottom = params.pad_bottom - common_padding_rows;
|
||||
const int64 input_pad_left = params.pad_left - common_padding_cols;
|
||||
const int64 input_pad_right = params.pad_right - common_padding_cols;
|
||||
|
||||
bool in_bounds =
|
||||
FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
|
||||
FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
|
||||
FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
|
||||
FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
|
||||
if (!in_bounds) {
|
||||
context->SetStatus(errors::InvalidArgument("Padding is too large."));
|
||||
return;
|
||||
}
|
||||
|
||||
// We need to call the const version of transformed_input.tensor()
|
||||
const Tensor& const_transformed_input = transformed_input;
|
||||
if (!std::is_same<T, qint8>::value) {
|
||||
T padding_value = -std::numeric_limits<T>::infinity();
|
||||
functor::PadInput<GPUDevice, T, int, 4>()(
|
||||
context->eigen_device<GPUDevice>(),
|
||||
To32Bit(const_transformed_input.tensor<T, 4>()),
|
||||
{{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
|
||||
{{static_cast<int>(input_pad_bottom),
|
||||
static_cast<int>(input_pad_right)}},
|
||||
To32Bit(padded_input.tensor<T, 4>()), data_format, padding_value);
|
||||
} else {
|
||||
context->SetStatus(errors::InvalidArgument(
|
||||
"Explicit padding not yet supported with qint8"));
|
||||
return;
|
||||
}
|
||||
transformed_input = padded_input;
|
||||
vertical_padding = common_padding_rows;
|
||||
horizontal_padding = common_padding_cols;
|
||||
tensor_in_rows = new_in_rows;
|
||||
tensor_in_cols = new_in_cols;
|
||||
}
|
||||
|
||||
/// Get ready to call cudnn
|
||||
se::dnn::PoolingDescriptor pooling_desc;
|
||||
pooling_desc.set_pooling_mode(pooling_mode)
|
||||
.set_window_height(params.window_rows)
|
||||
.set_window_width(params.window_cols)
|
||||
.set_vertical_stride(params.row_stride)
|
||||
.set_horizontal_stride(params.col_stride)
|
||||
.set_vertical_padding(vertical_padding)
|
||||
.set_horizontal_padding(horizontal_padding)
|
||||
.set_vertical_padding(params.pad_rows)
|
||||
.set_horizontal_padding(params.pad_cols)
|
||||
.set_propagate_nans(propagate_nans);
|
||||
|
||||
se::dnn::BatchDescriptor input_desc;
|
||||
input_desc.set_count(batch_size)
|
||||
.set_height(tensor_in_rows)
|
||||
.set_width(tensor_in_cols)
|
||||
.set_height(params.tensor_in_rows)
|
||||
.set_width(params.tensor_in_cols)
|
||||
.set_feature_map_count(depth)
|
||||
.set_layout(data_layout);
|
||||
|
||||
@ -393,32 +280,13 @@ void DnnPoolingOp<T>::Compute(OpKernelContext* context,
|
||||
#endif
|
||||
}
|
||||
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void PadInput<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const std::array<int, 2>& padding_left, \
|
||||
const std::array<int, 2>& padding_right, \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
|
||||
T padding_value); \
|
||||
extern template struct PadInput<GPUDevice, T, int, 4>;
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
DECLARE_GPU_SPEC(Eigen::half);
|
||||
DECLARE_GPU_SPEC(double);
|
||||
DECLARE_GPU_SPEC(int32);
|
||||
} // namespace functor
|
||||
|
||||
template <typename T>
|
||||
void DnnPoolingGradOp<T>::Compute(
|
||||
OpKernelContext* context, se::dnn::PoolingMode pooling_mode,
|
||||
const std::vector<int32>& size, const std::vector<int32>& stride,
|
||||
Padding padding, std::vector<int64> explicit_paddings,
|
||||
TensorFormat data_format, const Tensor* tensor_in, const Tensor* tensor_out,
|
||||
const Tensor& out_backprop, const TensorShape& tensor_in_shape,
|
||||
bool propagate_nans) {
|
||||
Padding padding, TensorFormat data_format, const Tensor* tensor_in,
|
||||
const Tensor* tensor_out, const Tensor& out_backprop,
|
||||
const TensorShape& tensor_in_shape, bool propagate_nans) {
|
||||
CHECK((pooling_mode != se::dnn::PoolingMode::kMaximum) ||
|
||||
(tensor_in && tensor_out))
|
||||
<< "For MaxPoolGrad, both tensor_in and tensor_out needs to be "
|
||||
@ -431,8 +299,8 @@ void DnnPoolingGradOp<T>::Compute(
|
||||
return;
|
||||
}
|
||||
|
||||
PoolParameters params{context, size, stride, padding,
|
||||
explicit_paddings, data_format, tensor_in_shape};
|
||||
PoolParameters params{context, size, stride,
|
||||
padding, data_format, tensor_in_shape};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -538,106 +406,6 @@ void DnnPoolingGradOp<T>::Compute(
|
||||
}
|
||||
#endif // CUDNN_VERSION < 7300
|
||||
|
||||
int64 vertical_padding = params.pad_top;
|
||||
int64 horizontal_padding = params.pad_left;
|
||||
|
||||
int batch_size = params.tensor_in_batch;
|
||||
int depth = params.depth;
|
||||
int tensor_in_cols = params.tensor_in_cols;
|
||||
int tensor_in_rows = params.tensor_in_rows;
|
||||
|
||||
int64 input_pad_top = 0;
|
||||
int64 input_pad_bottom = 0;
|
||||
int64 input_pad_left = 0;
|
||||
int64 input_pad_right = 0;
|
||||
|
||||
if (padding == EXPLICIT && (params.pad_top != params.pad_bottom ||
|
||||
params.pad_left != params.pad_right)) {
|
||||
// Pad the input in the same way we did during the forward pass, so that
|
||||
// cuDNN or MIOpen receives the same input during the backward pass function
|
||||
// as it did during the forward pass function.
|
||||
const int64 common_padding_rows =
|
||||
std::min(params.pad_top, params.pad_bottom);
|
||||
const int64 common_padding_cols =
|
||||
std::min(params.pad_left, params.pad_right);
|
||||
|
||||
Tensor padded_input;
|
||||
Tensor padded_input_backprop;
|
||||
const int64 padding_rows_diff =
|
||||
std::abs(params.pad_top - params.pad_bottom);
|
||||
const int64 padding_cols_diff =
|
||||
std::abs(params.pad_left - params.pad_right);
|
||||
|
||||
const int64 new_in_rows = tensor_in_rows + padding_rows_diff;
|
||||
const int64 new_in_cols = tensor_in_cols + padding_cols_diff;
|
||||
|
||||
VLOG(2) << "Create new tensor: "
|
||||
<< " original rows=" << tensor_in_rows
|
||||
<< " original cols=" << tensor_in_cols
|
||||
<< " padding_rows=" << new_in_rows
|
||||
<< " padding_cols=" << new_in_cols << " depth= " << depth
|
||||
<< " batch_size=" << batch_size << " kernel_rows"
|
||||
<< params.window_rows << " kernel_col" << params.window_cols
|
||||
<< " stride_rows" << params.row_stride;
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
context->allocate_temp(DataTypeToEnum<T>::value,
|
||||
ShapeFromFormat(data_format, batch_size,
|
||||
new_in_rows, new_in_cols, depth),
|
||||
&padded_input));
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
context->allocate_temp(DataTypeToEnum<T>::value,
|
||||
ShapeFromFormat(data_format, batch_size,
|
||||
new_in_rows, new_in_cols, depth),
|
||||
&transformed_input_backprop));
|
||||
|
||||
input_pad_top = params.pad_top - common_padding_rows;
|
||||
input_pad_bottom = params.pad_bottom - common_padding_rows;
|
||||
input_pad_left = params.pad_left - common_padding_cols;
|
||||
input_pad_right = params.pad_right - common_padding_cols;
|
||||
|
||||
bool in_bounds =
|
||||
FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
|
||||
FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
|
||||
FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
|
||||
FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
|
||||
if (!in_bounds) {
|
||||
context->SetStatus(errors::InvalidArgument("Padding is too large."));
|
||||
return;
|
||||
}
|
||||
|
||||
// PadInput functor requires input to be a const.
|
||||
|
||||
const Tensor& const_transformed_input = transformed_input;
|
||||
|
||||
if (!std::is_same<T, qint8>::value) {
|
||||
T padding_value = -std::numeric_limits<T>::infinity();
|
||||
functor::PadInput<GPUDevice, T, int, 4>()(
|
||||
context->eigen_device<GPUDevice>(),
|
||||
To32Bit(const_transformed_input.tensor<T, 4>()),
|
||||
{{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
|
||||
{{static_cast<int>(input_pad_bottom),
|
||||
static_cast<int>(input_pad_right)}},
|
||||
To32Bit(padded_input.tensor<T, 4>()), data_format, padding_value);
|
||||
} else {
|
||||
context->SetStatus(errors::InvalidArgument(
|
||||
"Explicit padding not yet supported with qint8"));
|
||||
return;
|
||||
}
|
||||
|
||||
transformed_input = padded_input;
|
||||
|
||||
vertical_padding = common_padding_rows;
|
||||
horizontal_padding = common_padding_cols;
|
||||
VLOG(2) << "vertical padding set to: " << vertical_padding
|
||||
<< " horizontal padding set to: " << horizontal_padding;
|
||||
tensor_in_rows = new_in_rows;
|
||||
tensor_in_cols = new_in_cols;
|
||||
}
|
||||
|
||||
/// Get ready to call cudnn
|
||||
se::dnn::PoolingDescriptor pooling_desc;
|
||||
pooling_desc.set_pooling_mode(pooling_mode)
|
||||
@ -645,8 +413,8 @@ void DnnPoolingGradOp<T>::Compute(
|
||||
.set_window_width(params.window_cols)
|
||||
.set_vertical_stride(params.row_stride)
|
||||
.set_horizontal_stride(params.col_stride)
|
||||
.set_vertical_padding(vertical_padding)
|
||||
.set_horizontal_padding(horizontal_padding)
|
||||
.set_vertical_padding(params.pad_rows)
|
||||
.set_horizontal_padding(params.pad_cols)
|
||||
.set_propagate_nans(propagate_nans);
|
||||
|
||||
se::dnn::BatchDescriptor orig_output_desc;
|
||||
@ -658,8 +426,8 @@ void DnnPoolingGradOp<T>::Compute(
|
||||
|
||||
se::dnn::BatchDescriptor orig_input_desc;
|
||||
orig_input_desc.set_count(params.tensor_in_batch)
|
||||
.set_height(tensor_in_rows)
|
||||
.set_width(tensor_in_cols)
|
||||
.set_height(params.tensor_in_rows)
|
||||
.set_width(params.tensor_in_cols)
|
||||
.set_feature_map_count(params.depth)
|
||||
.set_layout(data_layout);
|
||||
|
||||
@ -714,25 +482,6 @@ void DnnPoolingGradOp<T>::Compute(
|
||||
input_backprop->tensor<T, 4>());
|
||||
}
|
||||
#endif // CUDNN_VERSION < 7300
|
||||
if (padding == EXPLICIT && (params.pad_top != params.pad_bottom ||
|
||||
params.pad_left != params.pad_right)) {
|
||||
// Remove the padding that was added to the input shape above.
|
||||
if (!std::is_same<T, qint8>::value) {
|
||||
functor::PadInput<GPUDevice, T, int, 4>()(
|
||||
context->eigen_device<GPUDevice>(),
|
||||
To32Bit(const_cast<const Tensor&>(transformed_input_backprop)
|
||||
.tensor<T, 4>()),
|
||||
{{static_cast<int>(-input_pad_top),
|
||||
static_cast<int>(-input_pad_left)}},
|
||||
{{static_cast<int>(-input_pad_bottom),
|
||||
static_cast<int>(-input_pad_right)}},
|
||||
To32Bit(input_backprop->tensor<T, 4>()), data_format);
|
||||
} else {
|
||||
context->SetStatus(errors::InvalidArgument(
|
||||
"Explicit padding not yet supported with qint8"));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define DEFINE_DNN_OPS(T) \
|
||||
|
@ -18,12 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
@ -45,12 +40,9 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
// A helper class to manage sizes and shapes for pooling operations.
|
||||
struct PoolParameters {
|
||||
// Updates context->status if there is an invalid input.
|
||||
// explicit_paddings has eight elements if padding==EXPLIICT, and zero
|
||||
// elements otherwise.
|
||||
PoolParameters(OpKernelContext* context, const std::vector<int32>& ksize,
|
||||
const std::vector<int32>& stride, Padding padding,
|
||||
std::vector<int64> explicit_paddings, TensorFormat data_format,
|
||||
const TensorShape& tensor_in_shape);
|
||||
TensorFormat data_format, const TensorShape& tensor_in_shape);
|
||||
|
||||
// Returns the shape of the output for "forward" pooling operations.
|
||||
TensorShape forward_output_shape();
|
||||
@ -73,21 +65,13 @@ struct PoolParameters {
|
||||
int64 out_width;
|
||||
int out_depth;
|
||||
|
||||
int64 pad_top;
|
||||
int64 pad_bottom;
|
||||
int64 pad_left;
|
||||
int64 pad_right;
|
||||
|
||||
int64 pad_rows;
|
||||
int64 pad_cols;
|
||||
int pad_depth;
|
||||
|
||||
TensorFormat data_format;
|
||||
};
|
||||
|
||||
// Checks if the sizes of the paddings are less than the size of window.
|
||||
// This is required for MaxPool because it pads with -inf, so the pooling
|
||||
// window cannot fully cover the padded area.
|
||||
Status CheckPaddingSize(PoolParameters& params);
|
||||
|
||||
// An implementation of MaxPooling (forward).
|
||||
// TODO (yongtang): Remove MaxPoolingOp and use MaxPoolingV2Op,
|
||||
// QuantizedMaxPoolingOp depends on MaxPoolingOp so keep intact for now
|
||||
@ -122,10 +106,6 @@ class MaxPoolingOp : public OpKernel {
|
||||
errors::InvalidArgument("Sliding window stride field must "
|
||||
"specify 4 dimensions"));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||
if (padding_ == Padding::EXPLICIT) {
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr("explicit_paddings", &explicit_paddings_));
|
||||
}
|
||||
OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
|
||||
errors::Unimplemented(
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
@ -133,9 +113,8 @@ class MaxPoolingOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& tensor_in = context->input(0);
|
||||
PoolParameters params{
|
||||
context, ksize_, stride_, padding_, explicit_paddings_,
|
||||
FORMAT_NHWC, tensor_in.shape()};
|
||||
PoolParameters params{context, ksize_, stride_,
|
||||
padding_, FORMAT_NHWC, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -155,21 +134,9 @@ class MaxPoolingOp : public OpKernel {
|
||||
context, params.depth_window == params.depth_stride,
|
||||
errors::Unimplemented("Depthwise max pooling requires "
|
||||
"the depth window to equal the depth stride."));
|
||||
OP_REQUIRES(
|
||||
context, padding_ != EXPLICIT,
|
||||
errors::Unimplemented("Depthwise max pooling does not support "
|
||||
"explicit padding."));
|
||||
|
||||
DepthwiseMaxPool(context, output, tensor_in, params);
|
||||
} else {
|
||||
// MaxPoolingOp is only called on the GPU when the eigen_tensor label
|
||||
// is used. In this case, explicit padding is not supported
|
||||
if (std::is_same<Device, GPUDevice>::value &&
|
||||
padding_ == Padding::EXPLICIT) {
|
||||
context->SetStatus(errors::Unimplemented(
|
||||
"MaxPoolingOp does not support explicit padding."));
|
||||
return;
|
||||
}
|
||||
SpatialMaxPool(context, output, tensor_in, params, padding_);
|
||||
}
|
||||
}
|
||||
@ -235,8 +202,8 @@ class MaxPoolingOp : public OpKernel {
|
||||
auto shard = [¶ms, &in_mat, &out_mat](int64 start, int64 limit) {
|
||||
const int32 in_rows = params.tensor_in_rows;
|
||||
const int32 in_cols = params.tensor_in_cols;
|
||||
const int32 pad_top = params.pad_top;
|
||||
const int32 pad_left = params.pad_left;
|
||||
const int32 pad_rows = params.pad_rows;
|
||||
const int32 pad_cols = params.pad_cols;
|
||||
const int32 window_rows = params.window_rows;
|
||||
const int32 window_cols = params.window_cols;
|
||||
const int32 row_stride = params.row_stride;
|
||||
@ -258,8 +225,8 @@ class MaxPoolingOp : public OpKernel {
|
||||
for (int32 w = 0; w < in_cols; ++w) {
|
||||
// (h_start, h_end) * (w_start, w_end) is the range that the input
|
||||
// vector projects to.
|
||||
const int32 hpad = h + pad_top;
|
||||
const int32 wpad = w + pad_left;
|
||||
const int32 hpad = h + pad_rows;
|
||||
const int32 wpad = w + pad_cols;
|
||||
const int32 h_start = (hpad < window_rows)
|
||||
? 0
|
||||
: (hpad - window_rows) / row_stride + 1;
|
||||
@ -296,7 +263,6 @@ class MaxPoolingOp : public OpKernel {
|
||||
std::vector<int32> ksize_;
|
||||
std::vector<int32> stride_;
|
||||
Padding padding_;
|
||||
std::vector<int64> explicit_paddings_;
|
||||
TensorFormat data_format_;
|
||||
};
|
||||
|
||||
@ -314,7 +280,7 @@ struct LaunchMaxPoolingNoMask_NCHW_VECT_C<Eigen::GpuDevice> {
|
||||
params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols,
|
||||
params.depth, params.out_height, params.out_width, params.window_rows,
|
||||
params.window_cols, params.row_stride, params.col_stride,
|
||||
params.pad_top, params.pad_left,
|
||||
params.pad_rows, params.pad_cols,
|
||||
reinterpret_cast<int32*>(output->flat<qint8>().data()),
|
||||
context->eigen_gpu_device());
|
||||
if (!status) {
|
||||
@ -392,15 +358,8 @@ class MaxPoolingV2Op : public OpKernel {
|
||||
errors::Unimplemented(
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
|
||||
PoolParameters params{
|
||||
context,
|
||||
ksize,
|
||||
stride,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
data_format_,
|
||||
tensor_in.shape(),
|
||||
};
|
||||
PoolParameters params{context, ksize, stride,
|
||||
padding_, data_format_, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -496,8 +455,8 @@ class MaxPoolingV2Op : public OpKernel {
|
||||
auto shard = [¶ms, &in_mat, &out_mat](int64 start, int64 limit) {
|
||||
const int32 in_rows = params.tensor_in_rows;
|
||||
const int32 in_cols = params.tensor_in_cols;
|
||||
const int32 pad_top = params.pad_top;
|
||||
const int32 pad_left = params.pad_left;
|
||||
const int32 pad_rows = params.pad_rows;
|
||||
const int32 pad_cols = params.pad_cols;
|
||||
const int32 window_rows = params.window_rows;
|
||||
const int32 window_cols = params.window_cols;
|
||||
const int32 row_stride = params.row_stride;
|
||||
@ -519,8 +478,8 @@ class MaxPoolingV2Op : public OpKernel {
|
||||
for (int32 w = 0; w < in_cols; ++w) {
|
||||
// (h_start, h_end) * (w_start, w_end) is the range that the input
|
||||
// vector projects to.
|
||||
const int32 hpad = h + pad_top;
|
||||
const int32 wpad = w + pad_left;
|
||||
const int32 hpad = h + pad_rows;
|
||||
const int32 wpad = w + pad_cols;
|
||||
const int32 h_start = (hpad < window_rows)
|
||||
? 0
|
||||
: (hpad - window_rows) / row_stride + 1;
|
||||
@ -608,8 +567,8 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
|
||||
for (int w = 0; w < params.tensor_in_cols; ++w) {
|
||||
// (h_start, h_end) * (w_start, w_end) is the range that the input
|
||||
// vector projects to.
|
||||
const int hpad = h + params.pad_top;
|
||||
const int wpad = w + params.pad_left;
|
||||
const int hpad = h + params.pad_rows;
|
||||
const int wpad = w + params.pad_cols;
|
||||
const int h_start =
|
||||
(hpad < params.window_rows)
|
||||
? 0
|
||||
|
@ -43,7 +43,6 @@ class DnnPoolingOp {
|
||||
se::dnn::PoolingMode pooling_mode,
|
||||
const std::vector<int32>& size,
|
||||
const std::vector<int32>& stride, Padding padding,
|
||||
std::vector<int64> explicit_paddings,
|
||||
TensorFormat data_format, const Tensor& tensor_in,
|
||||
const TensorShape& tensor_out_shape, bool propagate_nans);
|
||||
};
|
||||
@ -59,7 +58,6 @@ class DnnPoolingGradOp {
|
||||
se::dnn::PoolingMode pooling_mode,
|
||||
const std::vector<int32>& size,
|
||||
const std::vector<int32>& stride, Padding padding,
|
||||
std::vector<int64> explicit_paddings,
|
||||
TensorFormat data_format, const Tensor* tensor_in,
|
||||
const Tensor* tensor_out, const Tensor& out_backprop,
|
||||
const TensorShape& tensor_in_shape, bool propagate_nans);
|
||||
|
@ -54,13 +54,8 @@ class QuantizedAvgPoolingOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& tensor_in = context->input(0);
|
||||
PoolParameters params{context,
|
||||
ksize_,
|
||||
stride_,
|
||||
padding_,
|
||||
/*explicit_paddings=*/{},
|
||||
FORMAT_NHWC,
|
||||
tensor_in.shape()};
|
||||
PoolParameters params{context, ksize_, stride_,
|
||||
padding_, FORMAT_NHWC, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
|
@ -841,12 +841,11 @@ REGISTER_OP("MaxPool")
|
||||
"uint16, qint8} = DT_FLOAT")
|
||||
.Attr("ksize: list(int) >= 4")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
.Attr(GetPaddingAttrStringWithExplicit())
|
||||
.Attr(GetExplicitPaddingsAttrString())
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.SetShapeFn(shape_inference::MaxPoolShapeWithExplicitPadding);
|
||||
.SetShapeFn(shape_inference::MaxPoolShape);
|
||||
|
||||
REGISTER_OP("MaxPoolV2")
|
||||
.Attr(
|
||||
@ -866,8 +865,7 @@ REGISTER_OP("MaxPoolV2")
|
||||
REGISTER_OP("MaxPoolGrad")
|
||||
.Attr("ksize: list(int) >= 4")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
.Attr(GetPaddingAttrStringWithExplicit())
|
||||
.Attr(GetExplicitPaddingsAttrString())
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Input("orig_input: T")
|
||||
.Input("orig_output: T")
|
||||
@ -892,7 +890,6 @@ REGISTER_OP("MaxPoolGradV2")
|
||||
return UnchangedShapeWithRank(c, 4);
|
||||
});
|
||||
|
||||
// TODO(b/150813181): Implement explicit padding.
|
||||
REGISTER_OP("MaxPoolGradGrad")
|
||||
.Attr("ksize: list(int) >= 4")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
|
@ -337,13 +337,13 @@ def NHWCToNCHW(input_tensor):
|
||||
"""Converts the input from the NHWC format to NCHW.
|
||||
|
||||
Args:
|
||||
input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape
|
||||
input_tensor: a 4- or 5-D tensor, or an array representing shape
|
||||
|
||||
Returns:
|
||||
converted tensor or shape array
|
||||
"""
|
||||
# tensor dim -> new axis order
|
||||
new_axes = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
|
||||
new_axes = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
|
||||
if isinstance(input_tensor, ops.Tensor):
|
||||
ndims = input_tensor.shape.ndims
|
||||
return array_ops.transpose(input_tensor, new_axes[ndims])
|
||||
|
@ -50,21 +50,15 @@ def GetDeviceScope(self, use_gpu=False):
|
||||
return self.session(use_gpu=use_gpu)
|
||||
|
||||
|
||||
def GetTestConfigs(include_nchw_vect_c=False, one_dimensional=False):
|
||||
def GetTestConfigs(include_nchw_vect_c=False):
|
||||
"""Get all the valid tests configs to run.
|
||||
|
||||
Args:
|
||||
include_nchw_vect_c: Whether to include NCHW_VECT_C in the test configs.
|
||||
one_dimensional: If it's a 1D test
|
||||
|
||||
Returns:
|
||||
all the valid test configs as tuples of data_format and use_gpu.
|
||||
"""
|
||||
if one_dimensional:
|
||||
test_configs = [("NWC", False), ("NWC", True)]
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
test_configs += [("NCW", True)]
|
||||
return test_configs
|
||||
test_configs = [("NHWC", False), ("NHWC", True)]
|
||||
if not test.is_gpu_available(cuda_only=True):
|
||||
tf_logging.info("NCHW and NCHW_VECT_C tests skipped because not run with "
|
||||
@ -112,12 +106,8 @@ def GetShrunkInceptionMaxPoolShapes(shrink=30):
|
||||
|
||||
class PoolingTest(test.TestCase):
|
||||
|
||||
def _isMaxPool(self, func):
|
||||
return func in (nn_ops.max_pool, nn_ops.max_pool_v2)
|
||||
|
||||
def _VerifyOneType(self, pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, data_type, expected, use_gpu, v2,
|
||||
use_negative_input=False):
|
||||
data_format, data_type, expected, use_gpu, v2):
|
||||
"""Verifies the output values of the pooling function.
|
||||
|
||||
Args:
|
||||
@ -131,8 +121,6 @@ class PoolingTest(test.TestCase):
|
||||
data_type: The data type to use to run the pooling operation.
|
||||
expected: An array containing the expected operation outputs.
|
||||
use_gpu: Whether we are running on GPU.
|
||||
v2: Whether to use v2 version.
|
||||
use_negative_input: If the input values should be negative.
|
||||
"""
|
||||
total_size = 1
|
||||
for s in input_sizes:
|
||||
@ -153,11 +141,10 @@ class PoolingTest(test.TestCase):
|
||||
data_type)
|
||||
# Initializes the input tensor with array containing incrementing
|
||||
# numbers from 1, wrapping round to -127 after 127 to support int8.
|
||||
y = -1 if use_negative_input else 1
|
||||
x = [(((f + 128) % 255) - 127)*y for f in range(total_size)]
|
||||
x = [((f + 128) % 255) - 127 for f in range(total_size)]
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
t = constant_op.constant(x, shape=input_sizes, dtype=data_type)
|
||||
if data_format in ("NCHW", "NCHW_VECT_C", "NCW"):
|
||||
if data_format in ("NCHW", "NCHW_VECT_C"):
|
||||
if data_format == "NCHW_VECT_C":
|
||||
t = test_util.NHWCToNCHW_VECT_C(t)
|
||||
t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
|
||||
@ -165,8 +152,6 @@ class PoolingTest(test.TestCase):
|
||||
t = test_util.NHWCToNCHW(t)
|
||||
ksize = test_util.NHWCToNCHW(ksize)
|
||||
strides = test_util.NHWCToNCHW(strides)
|
||||
if isinstance(padding, list):
|
||||
padding = test_util.NHWCToNCHW(padding)
|
||||
ksize_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
|
||||
strides_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
|
||||
if v2:
|
||||
@ -199,8 +184,7 @@ class PoolingTest(test.TestCase):
|
||||
self.assertAllCloseAccordingToType(expected, actual.flatten())
|
||||
|
||||
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, expected, use_gpu, v2,
|
||||
use_negative_input=False):
|
||||
data_format, expected, use_gpu, v2):
|
||||
"""Verifies the output values of the pooling function.
|
||||
|
||||
Args:
|
||||
@ -213,8 +197,6 @@ class PoolingTest(test.TestCase):
|
||||
data_format: The data format we use to run the pooling operation.
|
||||
expected: An array containing the expected operation outputs.
|
||||
use_gpu: Whether we are running on GPU.
|
||||
v2: Whether to use v2 version.
|
||||
use_negative_input: If the input values should be negative."
|
||||
"""
|
||||
if data_format == "NCHW_VECT_C":
|
||||
avg_pool_func = nn_ops.avg_pool
|
||||
@ -222,24 +204,17 @@ class PoolingTest(test.TestCase):
|
||||
if pool_func == avg_pool_func:
|
||||
tf_logging.info("NCHW_VECT_C not yet implemented for avg_pool")
|
||||
return
|
||||
if (self._isMaxPool(pool_func) and isinstance(padding, list)):
|
||||
tf_logging.info("NCHW_VECT_C not yet implemented for max pool" +
|
||||
" with explicit padding")
|
||||
return
|
||||
|
||||
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, dtypes.float32, expected, use_gpu, v2,
|
||||
use_negative_input)
|
||||
data_format, dtypes.float32, expected, use_gpu, v2)
|
||||
if not test.is_built_with_rocm():
|
||||
# double datatype is not supported for pooling ops on the ROCm platform
|
||||
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, dtypes.float64, expected, use_gpu, v2,
|
||||
use_negative_input)
|
||||
data_format, dtypes.float64, expected, use_gpu, v2)
|
||||
|
||||
if not use_gpu or test_util.GpuSupportsHalfMatMulAndConv():
|
||||
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, dtypes.float16, expected, use_gpu, v2,
|
||||
use_negative_input)
|
||||
data_format, dtypes.float16, expected, use_gpu, v2)
|
||||
|
||||
def _VerifyValues(self,
|
||||
pool_func,
|
||||
@ -249,9 +224,7 @@ class PoolingTest(test.TestCase):
|
||||
padding,
|
||||
expected,
|
||||
use_gpu,
|
||||
v2=False,
|
||||
one_dim=False,
|
||||
use_negative_input=False):
|
||||
v2=False):
|
||||
"""Verifies the output values of the pooling function.
|
||||
|
||||
Args:
|
||||
@ -263,16 +236,11 @@ class PoolingTest(test.TestCase):
|
||||
padding: Padding type.
|
||||
expected: An array containing the expected operation outputs.
|
||||
use_gpu: Whether we are running on GPU.
|
||||
v2: Whether to use v2 version.
|
||||
one_dim: If one dimensional pools should be done instead of two
|
||||
dimensional pools.
|
||||
use_negative_input: If the input values should be negative.
|
||||
"""
|
||||
for (data_format, use_gpu_2) in GetTestConfigs(True, one_dim):
|
||||
for (data_format, use_gpu_2) in GetTestConfigs(True):
|
||||
if use_gpu_2 == use_gpu:
|
||||
self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, expected, use_gpu, v2,
|
||||
use_negative_input)
|
||||
data_format, expected, use_gpu, v2)
|
||||
|
||||
def _testAvgPoolValidPadding(self, use_gpu):
|
||||
expected_output = [7.0, 8.0, 9.0]
|
||||
@ -499,101 +467,6 @@ class PoolingTest(test.TestCase):
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testMaxPoolZeroExplicitPadding(self, use_gpu):
|
||||
expected_output = [9.0]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding=[[0, 0], [0, 0], [0, 0], [0, 0]],
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testMaxPoolNegativeInputExpPadding(self, use_gpu):
|
||||
expected_output = [-1, -1, -1, -1]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding=[[0, 0], [2, 1], [2, 1], [0, 0]],
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu,
|
||||
use_negative_input=True)
|
||||
|
||||
def _testMaxPoolExplicitPadding(self, use_gpu):
|
||||
expected_output = [9.0, 9.0]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding=[[0, 0], [0, 2], [0, 1], [0, 0]],
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testMaxPoolExplicitPaddingAdvanced(self, use_gpu):
|
||||
expected_output = [7, 9, 11, 12, 19, 21, 23, 24, 31, 33, 35, 36, 31, 33,
|
||||
35, 36]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 6, 6, 1],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding=[[0, 0], [1, 2], [2, 1], [0, 0]],
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testMaxPoolNegativeInputExpPaddingAdv(self, use_gpu):
|
||||
expected_output = [-1, -1, -3, -5, -7, -7, -9, -11, -19, -19, -21, -23, -31,
|
||||
-31, -33, -35]
|
||||
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 6, 6, 1],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding=[[0, 0], [1, 2], [2, 1], [0, 0]],
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu,
|
||||
use_negative_input=True)
|
||||
|
||||
def _testMaxPoolExplicitPaddingV2(self, use_gpu):
|
||||
expected_output = [9.0, 9.0]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool_v2,
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding=[[0, 0], [0, 2], [0, 1], [0, 0]],
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testMaxPoolExplicitPadding1D(self, use_gpu):
|
||||
expected_output = [2.0, 3.0]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool1d,
|
||||
input_sizes=[1, 3, 1],
|
||||
ksize=[1, 2, 1],
|
||||
strides=[1, 2, 1],
|
||||
padding=[[0, 0], [0, 1], [0, 0]],
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu,
|
||||
one_dim=True)
|
||||
|
||||
def _testMaxPoolExplicitPadding1dV2(self, use_gpu):
|
||||
expected_output = [2.0, 3.0]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool_v2,
|
||||
input_sizes=[1, 3, 1],
|
||||
ksize=[1, 2, 1],
|
||||
strides=[1, 2, 1],
|
||||
padding=[[0, 0], [0, 1], [0, 0]],
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu,
|
||||
one_dim=True)
|
||||
|
||||
def _testMaxPoolSamePaddingNonSquareWindow(self, use_gpu):
|
||||
# input is:
|
||||
# [1.0, 2.0
|
||||
@ -745,14 +618,6 @@ class PoolingTest(test.TestCase):
|
||||
self._testMaxPoolSamePaddingPacket4(use_gpu)
|
||||
self._testMaxPoolSamePaddingPacket8(use_gpu)
|
||||
self._testMaxPoolEmptyInput(use_gpu)
|
||||
self._testMaxPoolZeroExplicitPadding(use_gpu)
|
||||
self._testMaxPoolExplicitPadding(use_gpu)
|
||||
self._testMaxPoolExplicitPaddingV2(use_gpu)
|
||||
self._testMaxPoolExplicitPadding1D(use_gpu)
|
||||
self._testMaxPoolExplicitPadding1dV2(use_gpu)
|
||||
self._testMaxPoolExplicitPaddingAdvanced(use_gpu)
|
||||
self._testMaxPoolNegativeInputExpPadding(use_gpu)
|
||||
self._testMaxPoolNegativeInputExpPaddingAdv(use_gpu)
|
||||
|
||||
# Tests for DepthwiseMaxPooling on CPU only.
|
||||
@test_util.run_deprecated_v1
|
||||
@ -1115,7 +980,7 @@ class PoolingTest(test.TestCase):
|
||||
data_format,
|
||||
use_gpu,
|
||||
x_init_value=None):
|
||||
"""Verifies the gradients of the max or avg pooling function.
|
||||
"""Verifies the gradients of the avg pooling function.
|
||||
|
||||
Args:
|
||||
pool_func: Function to be called, co.MaxPool, co.AvgPool,
|
||||
@ -1152,13 +1017,11 @@ class PoolingTest(test.TestCase):
|
||||
func_name = "max_pool"
|
||||
err_tolerance = 1e-3
|
||||
if data_format == "NCHW":
|
||||
ksize = [1, 1, window_rows, window_cols]
|
||||
ksize = [1, 1, window_rows, window_rows]
|
||||
strides = [1, 1, row_stride, col_stride]
|
||||
if isinstance(padding, list):
|
||||
padding = test_util.NHWCToNCHW(padding)
|
||||
t = test_util.NHWCToNCHW(input_tensor)
|
||||
else:
|
||||
ksize = [1, window_rows, window_cols, 1]
|
||||
ksize = [1, window_rows, window_rows, 1]
|
||||
strides = [1, row_stride, col_stride, 1]
|
||||
t = input_tensor
|
||||
t = pool_func(
|
||||
@ -1398,76 +1261,6 @@ class PoolingTest(test.TestCase):
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testMaxPoolExplicitPadding1(self, data_format, use_gpu):
|
||||
for pool_func in [nn_ops.max_pool]:
|
||||
self._ConstructAndTestGradient(
|
||||
pool_func,
|
||||
input_sizes=[1, 7, 7, 1],
|
||||
output_sizes=[1, 7, 7, 1],
|
||||
window_rows=3,
|
||||
window_cols=3,
|
||||
row_stride=1,
|
||||
col_stride=1,
|
||||
padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testMaxPoolExplicitPadding2(self, data_format, use_gpu):
|
||||
for pool_func in [nn_ops.max_pool]:
|
||||
self._ConstructAndTestGradient(
|
||||
pool_func,
|
||||
input_sizes=[1, 7, 7, 1],
|
||||
output_sizes=[1, 6, 8, 1],
|
||||
window_rows=3,
|
||||
window_cols=5,
|
||||
row_stride=1,
|
||||
col_stride=1,
|
||||
padding=[[0, 0], [0, 1], [2, 3], [0, 0]],
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testMaxPoolExplicitPaddingLeftGreater(self, data_format, use_gpu):
|
||||
for pool_func in [nn_ops.max_pool]:
|
||||
self._ConstructAndTestGradient(
|
||||
pool_func,
|
||||
input_sizes=[1, 7, 7, 1],
|
||||
output_sizes=[1, 6, 8, 1],
|
||||
window_rows=3,
|
||||
window_cols=5,
|
||||
row_stride=1,
|
||||
col_stride=1,
|
||||
padding=[[0, 0], [0, 1], [3, 2], [0, 0]],
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testMaxPoolExplicitPaddingBatchChannel(self, data_format, use_gpu):
|
||||
for pool_func in [nn_ops.max_pool]:
|
||||
self._ConstructAndTestGradient(
|
||||
pool_func,
|
||||
input_sizes=[4, 7, 7, 3],
|
||||
output_sizes=[4, 6, 8, 3],
|
||||
window_rows=3,
|
||||
window_cols=5,
|
||||
row_stride=1,
|
||||
col_stride=1,
|
||||
padding=[[0, 0], [0, 1], [3, 2], [0, 0]],
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testMaxPoolExplicitPaddingStrides(self, data_format, use_gpu):
|
||||
for pool_func in [nn_ops.max_pool]:
|
||||
self._ConstructAndTestGradient(
|
||||
pool_func,
|
||||
input_sizes=[1, 7, 7, 1],
|
||||
output_sizes=[1, 4, 3, 1],
|
||||
window_rows=3,
|
||||
window_cols=3,
|
||||
row_stride=2,
|
||||
col_stride=3,
|
||||
padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMaxPoolGrad(self):
|
||||
for (data_format, use_gpu) in GetTestConfigs():
|
||||
@ -1481,11 +1274,6 @@ class PoolingTest(test.TestCase):
|
||||
self._testMaxPoolGradSamePadding2_1(data_format, use_gpu)
|
||||
self._testMaxPoolGradSamePadding2_2(data_format, use_gpu)
|
||||
self._testMaxPoolGradSamePadding3_1(data_format, use_gpu)
|
||||
self._testMaxPoolExplicitPadding1(data_format, use_gpu)
|
||||
self._testMaxPoolExplicitPadding2(data_format, use_gpu)
|
||||
self._testMaxPoolExplicitPaddingStrides(data_format, use_gpu)
|
||||
self._testMaxPoolExplicitPaddingLeftGreater(data_format, use_gpu)
|
||||
self._testMaxPoolExplicitPaddingBatchChannel(data_format, use_gpu)
|
||||
|
||||
def _MaxPoolGrad(self, orig_input, orig_output, grad, window_rows,
|
||||
window_cols, row_stride, col_stride, padding, v2):
|
||||
@ -1506,16 +1294,9 @@ class PoolingTest(test.TestCase):
|
||||
A Tensor.
|
||||
"""
|
||||
pool_func = gen_nn_ops.max_pool_grad_v2 if v2 else gen_nn_ops.max_pool_grad
|
||||
if v2:
|
||||
return pool_func(orig_input, orig_output, grad,
|
||||
[1, window_rows, window_cols, 1],
|
||||
[1, row_stride, col_stride, 1], padding)
|
||||
else:
|
||||
padding, explicit_paddings = nn_ops.convert_padding(padding)
|
||||
return pool_func(orig_input, orig_output, grad,
|
||||
[1, window_rows, window_cols, 1],
|
||||
[1, row_stride, col_stride, 1], padding,
|
||||
explicit_paddings)
|
||||
return pool_func(orig_input, orig_output, grad,
|
||||
[1, window_rows, window_cols, 1],
|
||||
[1, row_stride, col_stride, 1], padding)
|
||||
|
||||
def _testMaxPoolGradDirect(self, input_data, output_backprop,
|
||||
expected_input_backprop, input_sizes, output_sizes,
|
||||
@ -1658,116 +1439,6 @@ class PoolingTest(test.TestCase):
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testMaxPoolGradZeroExplicitPadding(self):
|
||||
input_data = [
|
||||
1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
|
||||
0.0, 1.0
|
||||
]
|
||||
output_backprop = [11.0, 12.0, 13.0, 15.0, 16.0, 17.0, 19.0, 20.0, 21.0]
|
||||
expected_input_backprop = [
|
||||
11.0, 0.0, 25.0, 0.0, 0.0, 31.0, 0.0, 17.0, 19.0, 0.0, 41.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0
|
||||
]
|
||||
|
||||
for use_gpu in True, False:
|
||||
for v2 in [False]:
|
||||
self._testMaxPoolGradDirect(
|
||||
input_data,
|
||||
output_backprop,
|
||||
expected_input_backprop,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
output_sizes=[1, 3, 3, 1],
|
||||
window_rows=2,
|
||||
window_cols=2,
|
||||
row_stride=1,
|
||||
col_stride=1,
|
||||
padding=[[0, 0], [0, 0], [0, 0], [0, 0]],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testMaxPoolGradExplicitPadding_1(self):
|
||||
input_data = [
|
||||
1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
|
||||
0.0, 1.0
|
||||
]
|
||||
output_backprop = [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
|
||||
20.0, 21.0, 22.0]
|
||||
expected_input_backprop = [
|
||||
11.0, 0.0, 25.0, 0.0, 0.0, 31.0, 0.0, 49.0, 19.0, 0.0, 41.0, 0.0, 0.0,
|
||||
0.0, 0.0, 22.0
|
||||
]
|
||||
|
||||
for use_gpu in True, False:
|
||||
for v2 in [False]:
|
||||
self._testMaxPoolGradDirect(
|
||||
input_data,
|
||||
output_backprop,
|
||||
expected_input_backprop,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
output_sizes=[1, 3, 4, 1],
|
||||
window_rows=2,
|
||||
window_cols=2,
|
||||
row_stride=1,
|
||||
col_stride=1,
|
||||
padding=[[0, 0], [0, 0], [0, 1], [0, 0]],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testMaxPoolGradExplicitPadding_2(self):
|
||||
input_data = [
|
||||
1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0,
|
||||
0.0, 1.0
|
||||
]
|
||||
output_backprop = [11.0, 12.0, 13.0, 15.0, 16.0, 17.0, 19.0, 20.0, 21.0]
|
||||
expected_input_backprop = [
|
||||
54.0, 0.0, 30.0, 0.0, 0.0, 0.0, 0.0, 0.0, 39.0, 0.0, 21.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0
|
||||
]
|
||||
|
||||
for use_gpu in True, False:
|
||||
for v2 in [False]:
|
||||
self._testMaxPoolGradDirect(
|
||||
input_data,
|
||||
output_backprop,
|
||||
expected_input_backprop,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
output_sizes=[1, 3, 3, 1],
|
||||
window_rows=3,
|
||||
window_cols=3,
|
||||
row_stride=2,
|
||||
col_stride=2,
|
||||
padding=[[0, 0], [2, 1], [2, 1], [0, 0]],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testMaxPoolGradExplicitPadding_3(self):
|
||||
input_data = [
|
||||
-1.0, -5.0, -1.0, -5.0, -5.0, -1.0, -5.0, -1.0, -1.0, -5.0, -1.0, -5.0,
|
||||
-5.0, -1.0, -5.0, -1.0
|
||||
]
|
||||
output_backprop = [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
|
||||
20.0, 21.0, 22.0]
|
||||
expected_input_backprop = [
|
||||
11.0, 0.0, 25.0, 0.0, 0.0, 31.0, 0.0, 49.0, 19.0, 0.0, 41.0, 0.0, 0.0,
|
||||
0.0, 0.0, 22.0
|
||||
]
|
||||
|
||||
for use_gpu in True, False:
|
||||
for v2 in [False]:
|
||||
self._testMaxPoolGradDirect(
|
||||
input_data,
|
||||
output_backprop,
|
||||
expected_input_backprop,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
output_sizes=[1, 3, 4, 1],
|
||||
window_rows=2,
|
||||
window_cols=2,
|
||||
row_stride=1,
|
||||
col_stride=1,
|
||||
padding=[[0, 0], [0, 0], [0, 1], [0, 0]],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
@test_util.no_xla_auto_jit("b/123923733") # NaNs handled differently
|
||||
def _testMaxPoolGradDirectWithNans2_1(self):
|
||||
input_data = [float("nan")] * 16
|
||||
@ -1944,10 +1615,6 @@ class PoolingTest(test.TestCase):
|
||||
self._testMaxPoolGradDirect1_3()
|
||||
self._testMaxPoolGradDirectWithNans2_1()
|
||||
self._testMaxPoolGradDirectWithNans2_2()
|
||||
self._testMaxPoolGradZeroExplicitPadding()
|
||||
self._testMaxPoolGradExplicitPadding_1()
|
||||
self._testMaxPoolGradExplicitPadding_2()
|
||||
self._testMaxPoolGradExplicitPadding_3()
|
||||
|
||||
def _testMaxPoolGradGradValidPadding1_1(self, data_format, use_gpu):
|
||||
for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]:
|
||||
@ -2289,94 +1956,6 @@ class PoolingTest(test.TestCase):
|
||||
strides=[1, 1, 1, 1],
|
||||
padding="VALID")
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def _testEdgeCasesRaiseErrors(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Data formats NCHW_VECT_C is not yet supported with "
|
||||
"explicit padding"):
|
||||
nn_ops.max_pool(
|
||||
array_ops.placeholder(dtypes.float32, shape=[1, 3, 3, 1]),
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding=[[0, 0], [0, 1], [0, 1], [0, 0]],
|
||||
data_format="NCHW_VECT_C")
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Explicit padding is not yet supported with an input "
|
||||
"tensor of rank 5"):
|
||||
nn_ops.max_pool_v2(
|
||||
array_ops.placeholder(dtypes.float32, shape=[1, 3, 3, 1, 1]),
|
||||
ksize=[1, 2, 2, 1, 1],
|
||||
strides=[1, 2, 2, 1, 1],
|
||||
padding=[[0, 0], [0, 1], [0, 1], [0, 0]],
|
||||
data_format="NCHW")
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Attr 'padding' of 'MaxPoolV2' Op passed "
|
||||
"string 'EXPLICIT'"):
|
||||
gen_nn_ops.max_pool_v2(
|
||||
array_ops.placeholder(dtypes.float32, shape=[1, 3, 3, 1, 1]),
|
||||
ksize=[1, 2, 2, 1, 1],
|
||||
strides=[1, 2, 2, 1, 1],
|
||||
padding="EXPLICIT",
|
||||
data_format="NHWC")
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def _testEdgeCasesExcessPadding(self):
|
||||
with self.session(use_gpu=test.is_gpu_available()) as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
errors_impl.InvalidArgumentError,
|
||||
"Right padding 2 needs to be smaller than the window size 2"):
|
||||
input_sizes = [1, 3, 3, 1]
|
||||
x = [(((f + 128) % 255) - 127) for f in range(9)]
|
||||
t = constant_op.constant(x, shape=input_sizes, dtype=dtypes.float32)
|
||||
sess.run(gen_nn_ops.max_pool(
|
||||
t,
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="EXPLICIT",
|
||||
explicit_paddings=[0, 0, 0, 1, 0, 2, 0, 0],
|
||||
data_format="NHWC"))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def _testNegativePadding(self):
|
||||
with self.session(use_gpu=test.is_gpu_available()) as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "All elements of explicit_paddings must be "
|
||||
"nonnegative for"):
|
||||
input_sizes = [1, 3, 3, 1]
|
||||
x = [(((f + 128) % 255) - 127) for f in range(9)]
|
||||
t = constant_op.constant(x, shape=input_sizes, dtype=dtypes.float32)
|
||||
sess.run(gen_nn_ops.max_pool(
|
||||
t,
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="EXPLICIT",
|
||||
explicit_paddings=[0, 0, -1, -1, -1, -1, 0, 0],
|
||||
data_format="NHWC"))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def _testExplicitPaddingBatch(self):
|
||||
with self.session(use_gpu=test.is_gpu_available()) as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Nonzero explicit padding in the batch or depth "
|
||||
"dimensions is not supported"):
|
||||
input_sizes = [1, 3, 3, 1]
|
||||
x = [(((f + 128) % 255) - 127) for f in range(9)]
|
||||
t = constant_op.constant(x, shape=input_sizes, dtype=dtypes.float32)
|
||||
sess.run(gen_nn_ops.max_pool(
|
||||
t,
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="EXPLICIT",
|
||||
explicit_paddings=[1, 1, 1, 1, 1, 1, 0, 0],
|
||||
data_format="NHWC"))
|
||||
|
||||
def testExplicitPaddingEdgeCases(self):
|
||||
# Tests for Explicit padding.
|
||||
self._testEdgeCasesRaiseErrors()
|
||||
self._testEdgeCasesExcessPadding()
|
||||
self._testExplicitPaddingBatch()
|
||||
self._testNegativePadding()
|
||||
|
||||
|
||||
def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):
|
||||
|
||||
|
@ -688,7 +688,6 @@ def _MaxPoolGrad(op, grad):
|
||||
op.get_attr("ksize"),
|
||||
op.get_attr("strides"),
|
||||
padding=op.get_attr("padding"),
|
||||
explicit_paddings=op.get_attr("explicit_paddings"),
|
||||
data_format=op.get_attr("data_format"))
|
||||
|
||||
|
||||
|
@ -1752,14 +1752,12 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
|
||||
name=name)
|
||||
|
||||
|
||||
def convert_padding(padding, expected_length=4):
|
||||
def convert_padding(padding):
|
||||
"""Converts Python padding to C++ padding for ops which take EXPLICIT padding.
|
||||
|
||||
Args:
|
||||
padding: the `padding` argument for a Python op which supports EXPLICIT
|
||||
padding.
|
||||
expected_length: Expected number of entries in the padding list when
|
||||
explicit padding is used.
|
||||
|
||||
Returns:
|
||||
(padding, explicit_paddings) pair, which should be passed as attributes to a
|
||||
@ -1785,9 +1783,9 @@ def convert_padding(padding, expected_length=4):
|
||||
"be a list/tuple of size 2. Element with index %d of "
|
||||
"padding has size %d" % (i, len(dim_paddings)))
|
||||
explicit_paddings.extend(dim_paddings)
|
||||
if len(padding) != expected_length:
|
||||
raise ValueError("When padding is a list, it must be of size %d. Got "
|
||||
"padding of size: %d" % (expected_length, len(padding)))
|
||||
if len(padding) != 4:
|
||||
raise ValueError("When padding is a list, it must be of size 4. Got "
|
||||
"padding of size: %d" % len(padding))
|
||||
padding = "EXPLICIT"
|
||||
return padding, explicit_paddings
|
||||
|
||||
@ -4483,15 +4481,8 @@ def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
|
||||
of the window for each dimension of the input tensor.
|
||||
strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
|
||||
stride of the sliding window for each dimension of the input tensor.
|
||||
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
|
||||
padding algorithm to use, or a list indicating the explicit paddings at
|
||||
the start and end of each dimension. When explicit padding is used and
|
||||
data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
|
||||
pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
|
||||
and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
|
||||
[pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
|
||||
padding, the size of the paddings cannot be greater than the sliding
|
||||
window size.
|
||||
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
|
||||
the "returns" section of `tf.nn.convolution` for details.
|
||||
data_format: A string. Specifies the channel dimension. For N=1 it can be
|
||||
either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default)
|
||||
or "NCHW" and for N=3 either "NDHWC" (default) or "NCDHW".
|
||||
@ -4517,20 +4508,12 @@ def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
|
||||
else:
|
||||
channel_index = 1 if data_format.startswith("NC") else n + 1
|
||||
|
||||
if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
|
||||
raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
|
||||
"explicit padding")
|
||||
|
||||
ksize = _get_sequence(ksize, n, channel_index, "ksize")
|
||||
strides = _get_sequence(strides, n, channel_index, "strides")
|
||||
|
||||
if (isinstance(padding, (list, tuple)) and n == 3):
|
||||
raise ValueError("Explicit padding is not yet supported with an input "
|
||||
"tensor of rank 5")
|
||||
|
||||
max_pooling_ops = {
|
||||
1: max_pool1d,
|
||||
2: max_pool2d,
|
||||
2: gen_nn_ops.max_pool,
|
||||
3: gen_nn_ops.max_pool3d
|
||||
}
|
||||
|
||||
@ -4562,15 +4545,8 @@ def max_pool(value,
|
||||
The size of the window for each dimension of the input tensor.
|
||||
strides: An int or list of `ints` that has length `1`, `2` or `4`.
|
||||
The stride of the sliding window for each dimension of the input tensor.
|
||||
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
|
||||
padding algorithm to use, or a list indicating the explicit paddings at
|
||||
the start and end of each dimension. When explicit padding is used and
|
||||
data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
|
||||
pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
|
||||
and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
|
||||
[pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
|
||||
padding, the size of the paddings cannot be greater than the sliding
|
||||
window size.
|
||||
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
|
||||
See the "returns" section of `tf.nn.convolution` for details.
|
||||
data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
|
||||
name: Optional name for the operation.
|
||||
input: Alias for value.
|
||||
@ -4587,10 +4563,6 @@ def max_pool(value,
|
||||
|
||||
ksize = _get_sequence(ksize, 2, channel_index, "ksize")
|
||||
strides = _get_sequence(strides, 2, channel_index, "strides")
|
||||
if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
|
||||
raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
|
||||
"explicit padding")
|
||||
padding, explicit_paddings = convert_padding(padding)
|
||||
if ((np.isscalar(ksize) and ksize == 0) or
|
||||
(isinstance(ksize,
|
||||
(list, tuple, np.ndarray)) and any(v == 0 for v in ksize))):
|
||||
@ -4601,7 +4573,6 @@ def max_pool(value,
|
||||
ksize=ksize,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
explicit_paddings=explicit_paddings,
|
||||
data_format=data_format,
|
||||
name=name)
|
||||
|
||||
@ -4620,14 +4591,8 @@ def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
|
||||
window for each dimension of the input tensor.
|
||||
strides: An int or list of `ints` that has length `1` or `3`. The stride of
|
||||
the sliding window for each dimension of the input tensor.
|
||||
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
|
||||
padding algorithm to use, or a list indicating the explicit paddings at
|
||||
the start and end of each dimension. When explicit padding is used and
|
||||
data_format is `"NWC"`, this should be in the form `[[0, 0], [pad_left,
|
||||
pad_right], [0, 0]]`. When explicit padding used and data_format is
|
||||
`"NCW"`, this should be in the form `[[0, 0], [0, 0], [pad_left,
|
||||
pad_right]]`. When using explicit padding, the size of the paddings cannot
|
||||
be greater than the sliding window size.
|
||||
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
|
||||
the "returns" section of `tf.nn.convolution` for details.
|
||||
data_format: An optional string from: "NWC", "NCW". Defaults to "NWC".
|
||||
name: A name for the operation (optional).
|
||||
|
||||
@ -4636,17 +4601,11 @@ def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
|
||||
The max pooled output tensor.
|
||||
"""
|
||||
with ops.name_scope(name, "MaxPool1d", [input]) as name:
|
||||
if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
|
||||
raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
|
||||
"explicit padding")
|
||||
if data_format is None:
|
||||
data_format = "NWC"
|
||||
channel_index = 1 if data_format.startswith("NC") else 2
|
||||
ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
|
||||
strides = [1] + _get_sequence(strides, 1, channel_index, "strides")
|
||||
padding, explicit_paddings = convert_padding(padding, 3)
|
||||
if padding == "EXPLICIT":
|
||||
explicit_paddings = [0, 0] + explicit_paddings
|
||||
|
||||
expanding_dim = 1 if data_format == "NWC" else 2
|
||||
data_format = "NHWC" if data_format == "NWC" else "NCHW"
|
||||
@ -4657,7 +4616,6 @@ def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
|
||||
ksize=ksize,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
explicit_paddings=explicit_paddings,
|
||||
data_format=data_format,
|
||||
name=name)
|
||||
return array_ops.squeeze(result, expanding_dim)
|
||||
@ -4676,15 +4634,8 @@ def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
|
||||
the window for each dimension of the input tensor.
|
||||
strides: An int or list of `ints` that has length `1`, `2` or `4`. The
|
||||
stride of the sliding window for each dimension of the input tensor.
|
||||
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
|
||||
padding algorithm to use, or a list indicating the explicit paddings at
|
||||
the start and end of each dimension. When explicit padding is used and
|
||||
data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
|
||||
pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
|
||||
and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
|
||||
[pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
|
||||
padding, the size of the paddings cannot be greater than the sliding
|
||||
window size.
|
||||
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
|
||||
the "returns" section of `tf.nn.convolution` for details.
|
||||
data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
|
||||
name: Optional name for the operation.
|
||||
|
||||
@ -4699,17 +4650,12 @@ def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
|
||||
|
||||
ksize = _get_sequence(ksize, 2, channel_index, "ksize")
|
||||
strides = _get_sequence(strides, 2, channel_index, "strides")
|
||||
if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
|
||||
raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
|
||||
"explicit padding")
|
||||
padding, explicit_paddings = convert_padding(padding)
|
||||
|
||||
return gen_nn_ops.max_pool(
|
||||
input,
|
||||
ksize=ksize,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
explicit_paddings=explicit_paddings,
|
||||
data_format=data_format,
|
||||
name=name)
|
||||
# pylint: enable=redefined-builtin
|
||||
|
@ -2422,7 +2422,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "MaxPool"
|
||||
argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "MaxPool3D"
|
||||
@ -2438,7 +2438,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "MaxPoolGrad"
|
||||
argspec: "args=[\'orig_input\', \'orig_output\', \'grad\', \'ksize\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'None\'], "
|
||||
argspec: "args=[\'orig_input\', \'orig_output\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "MaxPoolGradGrad"
|
||||
|
@ -2422,7 +2422,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "MaxPool"
|
||||
argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "MaxPool3D"
|
||||
@ -2438,7 +2438,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "MaxPoolGrad"
|
||||
argspec: "args=[\'orig_input\', \'orig_output\', \'grad\', \'ksize\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'None\'], "
|
||||
argspec: "args=[\'orig_input\', \'orig_output\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "MaxPoolGradGrad"
|
||||
|
Loading…
Reference in New Issue
Block a user