diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 9df5cbdec06..bd5d6e4af4e 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -673,6 +673,116 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { return Status::OK(); } +Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + + string data_format; + Status s = c->GetAttr("data_format", &data_format); + + std::vector<int32> kernel_sizes; + std::vector<int32> strides; + + if (c->num_inputs() + 2 == num_inputs) { + TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); + + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + } else { + // Verify shape of ksize and strides input. + ShapeHandle size; + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); + + const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2); + if (kernel_sizes_tensor == nullptr) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + } + kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); + auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>(); + std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), kernel_sizes.begin()); + + const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); + if (strides_tensor == nullptr) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + } + strides.resize(strides_tensor->shape().num_elements()); + auto strides_vec = strides_tensor->flat<int32>(); + std::copy_n(&strides_vec(0), strides.size(), strides.begin()); + } + + if (strides.size() != 4) { + return errors::InvalidArgument( + "MaxPool requires the stride attribute to contain 4 values, but " + "got: ", + strides.size()); + } + if (kernel_sizes.size() != 4) { + return errors::InvalidArgument( + "MaxPool requires the ksize attribute to contain 4 values, but got: ", + kernel_sizes.size()); + } + + int32 stride_rows, stride_cols, stride_depth; + int32 kernel_rows, kernel_cols, kernel_depth; + + if (s.ok() && data_format == "NCHW") { + // Canonicalize input shape to NHWC so the shape inference code below can + // process it. 
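+    // For FORMAT_NCHW, GetTensorDimIndex<2> maps 'N'->0, 'C'->1, '0'->2 and
+    // '1'->3, so reading the dimensions back in N, 0, 1, C order below
+    // rebuilds the input shape in NHWC layout.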
+ auto dim = [&](char dimension) { + return c->Dim(input_shape, GetTensorDimIndex<2>(FORMAT_NCHW, dimension)); + }; + input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('C')}}); + stride_depth = strides[1]; + stride_rows = strides[2]; + stride_cols = strides[3]; + kernel_depth = kernel_sizes[1]; + kernel_rows = kernel_sizes[2]; + kernel_cols = kernel_sizes[3]; + } else { + stride_rows = strides[1]; + stride_cols = strides[2]; + stride_depth = strides[3]; + kernel_rows = kernel_sizes[1]; + kernel_cols = kernel_sizes[2]; + kernel_depth = kernel_sizes[3]; + } + + DimensionHandle batch_size_dim = c->Dim(input_shape, 0); + DimensionHandle in_rows_dim = c->Dim(input_shape, 1); + DimensionHandle in_cols_dim = c->Dim(input_shape, 2); + DimensionHandle in_depth_dim = c->Dim(input_shape, 3); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + ShapeHandle output_shape; + DimensionHandle output_rows, output_cols, output_depth; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); + + output_shape = + c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); + if (data_format == "NCHW") { + // Convert output shape back to expected NCHW data format. + auto dim = [&](char dimension) { + return c->Dim(output_shape, GetTensorDimIndex<2>(FORMAT_NHWC, dimension)); + }; + output_shape = c->MakeShape({{dim('N'), dim('C'), dim('0'), dim('1')}}); + } + + c->set_output(0, output_shape); + return Status::OK(); +} + Status Pool3DShape(shape_inference::InferenceContext* c) { ShapeHandle input_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 73b915652f6..fb79df07a4f 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -179,6 +179,9 @@ Status AvgPoolShape(shape_inference::InferenceContext* c); // Shape function for MaxPool-like operations. Status MaxPoolShape(shape_inference::InferenceContext* c); +// Shape function for MaxPoolV2-like operations. +Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs); + // Shape function for 3D Pooling operations. 
Status Pool3DShape(shape_inference::InferenceContext* c); diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index 6cb56797bff..8d825c13d76 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -208,22 +208,26 @@ class MaxPoolingGradOp : public OpKernel { errors::InvalidArgument("Default MaxPoolingGradOp only supports NHWC ", "on device type ", DeviceTypeString(context->device_type()))); - OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); - OP_REQUIRES(context, ksize_.size() == 4, - errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); + + if (context->num_inputs() == 3) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + OP_REQUIRES( + context, ksize_[3] == 1 && stride_[3] == 1, + errors::Unimplemented( + "MaxPoolingGrad is not yet supported on the depth dimension.")); + } + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, - errors::Unimplemented( - "Pooling is not yet supported on the batch dimension.")); - OP_REQUIRES( - context, ksize_[3] == 1 && stride_[3] == 1, - errors::Unimplemented( - "MaxPoolingGrad is not yet supported on the depth dimension.")); } void Compute(OpKernelContext* context) override { @@ -250,8 +254,35 @@ class MaxPoolingGradOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(), tensor_out.shape(), &tensor_out_arg_max)); + std::vector<int32> ksize = ksize_; + std::vector<int32> stride = stride_; + if (context->num_inputs() == 5) { + const Tensor& tensor_ksize = context->input(3); + auto value_ksize = tensor_ksize.flat<int32>(); + ksize.resize(tensor_ksize.shape().num_elements()); + std::copy_n(&value_ksize(0), ksize.size(), ksize.begin()); - PoolParameters params{context, ksize_, stride_, + const Tensor& tensor_stride = context->input(4); + auto value_stride = tensor_stride.flat<int32>(); + stride.resize(tensor_stride.shape().num_elements()); + std::copy_n(&value_stride(0), stride.size(), stride.begin()); + } + + OP_REQUIRES(context, ksize.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES(context, stride.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + OP_REQUIRES( + context, ksize[3] == 1 && stride[3] == 1, + errors::Unimplemented( + "MaxPoolingGrad is not yet supported on the depth dimension.")); + + PoolParameters params{context, ksize, stride, padding_, FORMAT_NHWC, tensor_in.shape()}; if (!context->status().ok()) { return; @@ -309,20 +340,22 @@ class 
MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); OP_REQUIRES(context, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); - OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); - OP_REQUIRES(context, ksize_.size() == 4, - errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); + if (context->num_inputs() == 3) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N'); + const int32 stride_n = GetTensorDim(stride_, data_format_, 'N'); + OP_REQUIRES(context, ksize_n == 1 && stride_n == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N'); - const int32 stride_n = GetTensorDim(stride_, data_format_, 'N'); - OP_REQUIRES(context, ksize_n == 1 && stride_n == 1, - errors::Unimplemented( - "Pooling is not yet supported on the batch dimension.")); use_dnn_ = CanUseCudnn(); } @@ -343,15 +376,40 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel { TensorShape output_shape = tensor_in.shape(); + std::vector<int32> ksize = ksize_; + std::vector<int32> stride = stride_; + if (context->num_inputs() == 5) { + const Tensor& tensor_ksize = context->input(3); + auto value_ksize = tensor_ksize.flat<int32>(); + ksize.resize(tensor_ksize.shape().num_elements()); + std::copy_n(&value_ksize(0), ksize.size(), ksize.begin()); + + const Tensor& tensor_stride = context->input(4); + auto value_stride = tensor_stride.flat<int32>(); + stride.resize(tensor_stride.shape().num_elements()); + std::copy_n(&value_stride(0), stride.size(), stride.begin()); + } + OP_REQUIRES(context, ksize.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES(context, stride.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N'); + const int32 stride_n = GetTensorDim(stride, data_format_, 'N'); + OP_REQUIRES(context, ksize_n == 1 && stride_n == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + if (use_dnn_) { DnnPoolingGradOp<T>::Compute( - context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_, - stride_, padding_, data_format_, &tensor_in, &tensor_out, - out_backprop, output_shape); + context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize, + stride, padding_, data_format_, &tensor_in, &tensor_out, out_backprop, + output_shape); } else { CHECK(data_format_ == FORMAT_NHWC) << "Non-Cudnn MaxPoolGrad only supports NHWC format"; - MaxPoolingBackwardCustomKernel<T>(context, ksize_, stride_, padding_, + MaxPoolingBackwardCustomKernel<T>(context, ksize, stride, padding_, &tensor_in, 
out_backprop, output_shape); } } @@ -386,22 +444,25 @@ class MaxPoolingGradGradOp : public OpKernel { errors::InvalidArgument( "Default MaxPoolingGradGradOp only supports NHWC ", "on device type ", DeviceTypeString(context->device_type()))); - OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); - OP_REQUIRES(context, ksize_.size() == 4, - errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, - errors::Unimplemented( - "Pooling is not yet supported on the batch dimension.")); - OP_REQUIRES( - context, ksize_[3] == 1 && stride_[3] == 1, - errors::Unimplemented( - "MaxPoolingGradGrad is not yet supported on the depth dimension.")); + + if (context->num_inputs() == 3) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + OP_REQUIRES(context, ksize_[3] == 1 && stride_[3] == 1, + errors::Unimplemented("MaxPoolingGradGrad is not yet " + "supported on the depth dimension.")); + } } void Compute(OpKernelContext* context) override { @@ -419,7 +480,35 @@ class MaxPoolingGradGradOp : public OpKernel { context, out_grad_backprop.dims() == 4, errors::InvalidArgument("out_grad_backprop must be 4-dimensional")); - PoolParameters params{context, ksize_, stride_, + std::vector<int32> ksize = ksize_; + std::vector<int32> stride = stride_; + if (context->num_inputs() == 5) { + const Tensor& tensor_ksize = context->input(3); + auto value_ksize = tensor_ksize.flat<int32>(); + ksize.resize(tensor_ksize.shape().num_elements()); + std::copy_n(&value_ksize(0), ksize.size(), ksize.begin()); + + const Tensor& tensor_stride = context->input(4); + auto value_stride = tensor_stride.flat<int32>(); + stride.resize(tensor_stride.shape().num_elements()); + std::copy_n(&value_stride(0), stride.size(), stride.begin()); + } + + OP_REQUIRES(context, ksize.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES(context, stride.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + OP_REQUIRES( + context, ksize[3] == 1 && stride[3] == 1, + errors::Unimplemented( + "MaxPoolingGrad is not yet supported on the depth dimension.")); + + PoolParameters params{context, ksize, stride, padding_, FORMAT_NHWC, tensor_in.shape()}; Tensor* output = nullptr; OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( @@ -474,7 +563,7 @@ class MaxPoolingGradGradOp : public OpKernel { // tensor_out_as_matrix with the corresponding values in // top_diff_as_matrix. 
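+    // Each shard below covers a contiguous range [start, limit) of batch
+    // images; the work is partitioned across worker threads via Shard().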
auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat]( - int64 start, int64 limit) { + int64 start, int64 limit) { const int32 depth = params.depth; const int32 in_rows = params.tensor_in_rows; const int32 in_cols = params.tensor_in_cols; @@ -555,20 +644,22 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); OP_REQUIRES(context, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); - OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); - OP_REQUIRES(context, ksize_.size() == 4, - errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); + if (context->num_inputs() == 3) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N'); + const int32 stride_n = GetTensorDim(stride_, data_format_, 'N'); + OP_REQUIRES(context, ksize_n == 1 && stride_n == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N'); - const int32 stride_n = GetTensorDim(stride_, data_format_, 'N'); - OP_REQUIRES(context, ksize_n == 1 && stride_n == 1, - errors::Unimplemented( - "Pooling is not yet supported on the batch dimension.")); } void Compute(OpKernelContext* context) override { @@ -590,7 +681,33 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, tensor_out.shape(), &output)); - PoolParameters params{context, ksize_, stride_, + std::vector<int32> ksize = ksize_; + std::vector<int32> stride = stride_; + if (context->num_inputs() == 5) { + const Tensor& tensor_ksize = context->input(3); + auto value_ksize = tensor_ksize.flat<int32>(); + ksize.resize(tensor_ksize.shape().num_elements()); + std::copy_n(&value_ksize(0), ksize.size(), ksize.begin()); + + const Tensor& tensor_stride = context->input(4); + auto value_stride = tensor_stride.flat<int32>(); + stride.resize(tensor_stride.shape().num_elements()); + std::copy_n(&value_stride(0), stride.size(), stride.begin()); + } + + OP_REQUIRES(context, ksize.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES(context, stride.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N'); + const int32 stride_n = GetTensorDim(stride, data_format_, 'N'); + OP_REQUIRES(context, ksize_n == 1 && stride_n == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + + PoolParameters params{context, ksize, stride, + padding_, data_format_, tensor_in.shape()}; functor::MaxPoolGradBackwardNoMask<T>()( @@ -669,6 +786,84 @@ class MaxPoolingNoMaskOp : public OpKernel { TensorFormat
data_format_; }; +template <typename Device, typename T> +class MaxPoolingNoMaskV2Op : public OpKernel { + public: + explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context) + : OpKernel(context) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "Default MaxPoolingNoMaskOp only supports NHWC on device type ", + DeviceTypeString(context->device_type()))); + if (context->num_inputs() == 1) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + + std::vector<int32> ksize = ksize_; + std::vector<int32> stride = stride_; + + if (context->num_inputs() != 1) { + const Tensor& tensor_ksize = context->input(1); + auto value_ksize = tensor_ksize.flat<int32>(); + ksize.resize(tensor_ksize.shape().num_elements()); + std::copy_n(&value_ksize(0), ksize.size(), ksize.begin()); + + const Tensor& tensor_stride = context->input(2); + auto value_stride = tensor_stride.flat<int32>(); + stride.resize(tensor_stride.shape().num_elements()); + std::copy_n(&value_stride(0), stride.size(), stride.begin()); + } + OP_REQUIRES(context, ksize.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES(context, stride.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + PoolParameters params{context, ksize, stride, + padding_, data_format_, tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + + TensorShape out_shape({params.tensor_in_batch, params.out_height, + params.out_width, params.depth}); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in, + output); + } + + private: + std::vector<int32> ksize_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; +}; + template <typename Device, typename T> struct LaunchMaxPoolingWithArgmax; @@ -878,6 +1073,95 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel { bool use_dnn_; }; +template <typename T> +class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel { + public: + typedef GPUDevice Device; + explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context) + : OpKernel(context) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + if (context->num_inputs() == 1) { + OP_REQUIRES_OK(context, 
context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N'); + const int32 stride_n = GetTensorDim(stride_, data_format_, 'N'); + OP_REQUIRES(context, ksize_n == 1 && stride_n == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + use_dnn_ = CanUseCudnn(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + + std::vector<int32> ksize = ksize_; + std::vector<int32> stride = stride_; + + if (context->num_inputs() != 1) { + const Tensor& tensor_ksize = context->input(1); + auto value_ksize = tensor_ksize.flat<int32>(); + ksize.resize(tensor_ksize.shape().num_elements()); + std::copy_n(&value_ksize(0), ksize.size(), ksize.begin()); + + const Tensor& tensor_stride = context->input(2); + auto value_stride = tensor_stride.flat<int32>(); + stride.resize(tensor_stride.shape().num_elements()); + std::copy_n(&value_stride(0), stride.size(), stride.begin()); + } + OP_REQUIRES(context, ksize.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES(context, stride.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N'); + const int32 stride_n = GetTensorDim(stride, data_format_, 'N'); + OP_REQUIRES(context, ksize_n == 1 && stride_n == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + + PoolParameters params{context, ksize, stride, + padding_, data_format_, tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + + TensorShape out_shape = + ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height, + params.out_width, params.depth); + if (use_dnn_ && data_format_ == FORMAT_NCHW) { + DnnPoolingOp<T>::Compute( + context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize, + stride, padding_, data_format_, tensor_in, out_shape); + } else { + CHECK(data_format_ == FORMAT_NHWC) + << "Non-Cudnn MaxPool only supports NHWC format"; + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in, + output); + } + } + + private: + std::vector<int32> ksize_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; + bool use_dnn_; +}; + template <typename T> struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> { static void launch(OpKernelContext* context, const PoolParameters& params, @@ -969,13 +1253,28 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> { MaxPoolingGradOp<D##Device, T>); \ REGISTER_KERNEL_BUILDER( \ Name("MaxPoolGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \ - MaxPoolingGradGradOp<D##Device, T>); + MaxPoolingGradGradOp<D##Device, T>); \ + REGISTER_KERNEL_BUILDER(Name("MaxPoolGradV2") \ + .Device(DEVICE_##D) \ + .HostMemory("ksize") \ + .HostMemory("strides") \ + .TypeConstraint<T>("T"), \ + MaxPoolingGradOp<D##Device, T>); \ + REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradV2") \ + 
.Device(DEVICE_##D) \ + .HostMemory("ksize") \ + .HostMemory("strides") \ + .TypeConstraint<T>("T"), \ + MaxPoolingGradGradOp<D##Device, T>); // Below kernels implemented only for CPU device. -#define REGISTER_CPU_ONLY_POOL_KERNELS(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - MaxPoolingOp<CPUDevice, T>); +#define REGISTER_CPU_ONLY_POOL_KERNELS(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + MaxPoolingOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER( \ + Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + MaxPoolingV2Op<CPUDevice, T>); TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS); #undef REGISTER_CPU_ONLY_POOL_KERNELS @@ -1015,9 +1314,22 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS); .TypeConstraint<T>("T") \ .Label("eigen_tensor"), \ MaxPoolingOp<GPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("ksize") \ + .HostMemory("strides") \ + .TypeConstraint<T>("T") \ + .Label("eigen_tensor"), \ + MaxPoolingV2Op<GPUDevice, T>); \ REGISTER_KERNEL_BUILDER( \ Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ MaxPoolingNoMaskOp<GPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("ksize") \ + .HostMemory("strides") \ + .TypeConstraint<T>("T"), \ + MaxPoolingNoMaskV2Op<GPUDevice, T>); \ REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") \ .Device(DEVICE_GPU) \ .TypeConstraint<int64>("Targmax") \ diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index 2c097c0ce24..1b59c18df79 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -69,6 +69,8 @@ struct PoolParameters { }; // An implementation of MaxPooling (forward). 
+// TODO (yongtang): Remove MaxPoolingOp and use MaxPoolingV2Op, +// QuantizedMaxPoolingOp depends on MaxPoolingOp so keep intact for now template <typename Device, typename T> class MaxPoolingOp : public OpKernel { public: @@ -254,6 +256,219 @@ class MaxPoolingOp : public OpKernel { TensorFormat data_format_; }; +template <typename Device, typename T> +class MaxPoolingV2Op : public OpKernel { + public: + explicit MaxPoolingV2Op(OpKernelConstruction* context) : OpKernel(context) { + string data_format; + auto status = context->GetAttr("data_format", &data_format); + if (status.ok()) { + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Default MaxPoolingOp only supports NHWC.")); + } else { + data_format_ = FORMAT_NHWC; + } + if (context->num_inputs() == 1) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + + std::vector<int32> ksize = ksize_; + std::vector<int32> stride = stride_; + + if (context->num_inputs() != 1) { + const Tensor& tensor_ksize = context->input(1); + auto value_ksize = tensor_ksize.flat<int32>(); + ksize.resize(tensor_ksize.shape().num_elements()); + std::copy_n(&value_ksize(0), ksize.size(), ksize.begin()); + + const Tensor& tensor_stride = context->input(2); + auto value_stride = tensor_stride.flat<int32>(); + stride.resize(tensor_stride.shape().num_elements()); + std::copy_n(&value_stride(0), stride.size(), stride.begin()); + } + + OP_REQUIRES(context, ksize.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES(context, stride.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + + PoolParameters params{context, ksize, stride, + padding_, FORMAT_NHWC, tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, params.forward_output_shape(), &output)); + + if (params.depth_window > 1) { + // Validate spec against the current implementation. A + // relaxation of these requirements would be ideal. 
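+      // Depthwise pooling is only implemented for non-overlapping windows:
+      // the depth window must evenly divide the input depth and must equal
+      // the depth stride (both checked below).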
+ OP_REQUIRES(context, params.depth % params.depth_window == 0, + errors::Unimplemented( + "Depthwise max pooling requires " + "the depth window to evenly divide the input depth.")); + OP_REQUIRES( + context, params.depth_window == params.depth_stride, + errors::Unimplemented("Depthwise max pooling requires " + "the depth window to equal the depth stride.")); + + DepthwiseMaxPool(context, output, tensor_in, params); + } else { + SpatialMaxPool(context, output, tensor_in, params, padding_); + } + } + + private: + // Single-threaded implementation of DepthwiseMaxPool which + // does not handle all of the same options as SpatialMaxPool + // (strict assumptions on no padding, stride). + // + // TODO(vrv): implement a more general depthwise-max pool that works + // on GPU as well. + void DepthwiseMaxPool(OpKernelContext* context, Tensor* output, + const Tensor& tensor_in, const PoolParameters& params) { + Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + in_by_pool(tensor_in.flat<T>().data(), params.depth_window, + tensor_in.NumElements() / params.depth_window); + Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> out_by_pool( + output->flat<T>().data(), 1, output->NumElements()); + out_by_pool = in_by_pool.colwise().maxCoeff(); + } + + void SpatialMaxPool(OpKernelContext* context, Tensor* output, + const Tensor& tensor_in, const PoolParameters& params, + const Padding& padding) { + // On GPU, use Eigen's Spatial Max Pooling. On CPU, use an + // EigenMatrix version that is currently faster than Eigen's + // Spatial MaxPooling implementation. + // + // TODO(vrv): Remove this once we no longer need it. + if (std::is_same<Device, GPUDevice>::value) { + Eigen::PaddingType pt = BrainPadding2EigenPadding(padding); + functor::SpatialMaxPooling<Device, T>()( + context->eigen_device<Device>(), output->tensor<T, 4>(), + tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols, + params.row_stride, params.col_stride, pt); + } else { + typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + ConstEigenMatrixMap; + typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + EigenMatrixMap; + + ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth, + params.tensor_in_cols * params.tensor_in_rows * + params.tensor_in_batch); + EigenMatrixMap out_mat( + output->flat<T>().data(), params.depth, + params.out_width * params.out_height * params.tensor_in_batch); + + const DeviceBase::CpuWorkerThreads& worker_threads = + *(context->device()->tensorflow_cpu_worker_threads()); + + // The following code basically does the following: + // 1. Flattens the input and output tensors into two dimensional arrays. + // tensor_in_as_matrix: + // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch) + // output_as_matrix: + // depth by (out_width * out_height * tensor_in_batch) + // + // 2. Walks through the set of columns in the flattened + // tensor_in_as_matrix, + // and updates the corresponding column(s) in output_as_matrix with the + // max value. 
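+      // Each shard handles a contiguous range [start, limit) of batch images:
+      // its slice of the output is first initialized to the lowest
+      // representable value of T and then updated with cwiseMax as the input
+      // columns are visited.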
+ auto shard = [&params, &in_mat, &out_mat](int64 start, int64 limit) { + + const int32 in_rows = params.tensor_in_rows; + const int32 in_cols = params.tensor_in_cols; + const int32 pad_rows = params.pad_rows; + const int32 pad_cols = params.pad_cols; + const int32 window_rows = params.window_rows; + const int32 window_cols = params.window_cols; + const int32 row_stride = params.row_stride; + const int32 col_stride = params.col_stride; + const int32 out_height = params.out_height; + const int32 out_width = params.out_width; + + { + // Initializes the output tensor with MIN<T>. + const int32 output_image_size = out_height * out_width * params.depth; + EigenMatrixMap out_shard(out_mat.data() + start * output_image_size, + 1, (limit - start) * output_image_size); + out_shard.setConstant(Eigen::NumTraits<T>::lowest()); + } + + for (int32 b = start; b < limit; ++b) { + const int32 out_offset_batch = b * out_height; + for (int32 h = 0; h < in_rows; ++h) { + for (int32 w = 0; w < in_cols; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int32 hpad = h + pad_rows; + const int32 wpad = w + pad_cols; + const int32 h_start = (hpad < window_rows) + ? 0 + : (hpad - window_rows) / row_stride + 1; + const int32 h_end = std::min(hpad / row_stride + 1, out_height); + const int32 w_start = (wpad < window_cols) + ? 0 + : (wpad - window_cols) / col_stride + 1; + const int32 w_end = std::min(wpad / col_stride + 1, out_width); + // compute elementwise max + const int32 in_offset = (b * in_rows + h) * in_cols + w; + for (int32 ph = h_start; ph < h_end; ++ph) { + const int32 out_offset_base = + (out_offset_batch + ph) * out_width; + for (int32 pw = w_start; pw < w_end; ++pw) { + const int32 out_offset = out_offset_base + pw; + out_mat.col(out_offset) = + out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset)); + } + } + } + } + } + }; + + // TODO(andydavis) Consider sharding across batch x rows x cols. + // TODO(andydavis) Consider a higher resolution shard cost model. + const int64 shard_cost = + params.tensor_in_rows * params.tensor_in_cols * params.depth; + Shard(worker_threads.num_threads, worker_threads.workers, + params.tensor_in_batch, shard_cost, shard); + } + } + + std::vector<int32> ksize_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; +}; + template <typename Device, typename T> void SpatialAvgPool(OpKernelContext* context, Tensor* output, const Tensor& input, const PoolParameters& params, diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 10187425214..0a96258dd1f 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1368,6 +1368,34 @@ input: 4-D input to pool over. output: The max pooled output tensor. )doc"); +REGISTER_OP("MaxPoolV2") + .Attr("T: realnumbertype = DT_FLOAT") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Input("input: T") + .Input("ksize: int32") + .Input("strides: int32") + .Output("output: T") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3)); + return Status::OK(); + }) + .Doc(R"doc( +Performs max pooling on the input. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +data_format: Specify the data format of the input and output data.
With the + default format "NHWC", the data is stored in the order of: + [batch, in_height, in_width, in_channels]. + Alternatively, the format could be "NCHW", the data storage order of: + [batch, in_channels, in_height, in_width]. +input: 4-D input to pool over. +output: The max pooled output tensor. +)doc"); + REGISTER_OP("MaxPoolGrad") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") @@ -1399,6 +1427,37 @@ grad: 4-D. Gradients w.r.t. the output of `max_pool`. output: Gradients w.r.t. the input to `max_pool`. )doc"); +REGISTER_OP("MaxPoolGradV2") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Input("orig_input: T") + .Input("orig_output: T") + .Input("grad: T") + .Input("ksize: int32") + .Input("strides: int32") + .Output("output: T") + .Attr("T: realnumbertype = DT_FLOAT") + .SetShapeFn([](InferenceContext* c) { + return UnchangedShapeWithRank(c, 4); + }) + .Doc(R"doc( +Computes gradients of the maxpooling function. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +data_format: Specify the data format of the input and output data. With the + default format "NHWC", the data is stored in the order of: + [batch, in_height, in_width, in_channels]. + Alternatively, the format could be "NCHW", the data storage order of: + [batch, in_channels, in_height, in_width]. +orig_input: The original input tensor. +orig_output: The original output tensor. +grad: 4-D. Gradients w.r.t. the output of `max_pool`. +output: Gradients w.r.t. the input to `max_pool`. +)doc"); + REGISTER_OP("MaxPoolGradGrad") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") @@ -1436,6 +1495,43 @@ grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. output: Gradients of gradients w.r.t. the input to `max_pool`. )doc"); +REGISTER_OP("MaxPoolGradGradV2") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Input("orig_input: T") + .Input("orig_output: T") + .Input("grad: T") + .Input("ksize: int32") + .Input("strides: int32") + .Output("output: T") + .Attr("T: realnumbertype") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5)); + ShapeHandle unused; + // Validate 'orig_input' is the same shape as 'grad' + TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); + // Validate 'orig_output' is same shape as 'output' + TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); + return Status::OK(); + }) + .Doc(R"doc( +Computes second-order gradients of the maxpooling function. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +data_format: Specify the data format of the input and output data. With the + default format "NHWC", the data is stored in the order of: + [batch, in_height, in_width, in_channels]. + Alternatively, the format could be "NCHW", the data storage order of: + [batch, in_channels, in_height, in_width]. +orig_input: The original input tensor. +orig_output: The original output tensor. +grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +output: Gradients of gradients w.r.t. the input to `max_pool`. 
+)doc"); + REGISTER_OP("MaxPoolWithArgmax") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index f5fb7e4e03e..da14871c872 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import nn_ops +from tensorflow.python.framework import ops import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -76,7 +77,7 @@ def GetShrunkInceptionMaxPoolShapes(shrink=30): class PoolingTest(test.TestCase): def _VerifyOneType(self, pool_func, input_sizes, ksize, strides, padding, - data_format, data_type, expected, use_gpu): + data_format, data_type, expected, use_gpu, v2): """Verifies the output values of the pooling function. Args: @@ -103,20 +104,35 @@ class PoolingTest(test.TestCase): t = test_util.NHWCToNCHW(t) ksize = test_util.NHWCToNCHW(ksize) strides = test_util.NHWCToNCHW(strides) - t = pool_func( - t, - ksize=ksize, - strides=strides, - padding=padding, - data_format=data_format) + v2 = v2 and data_format != "NCHW" + ksize_placeholder = array_ops.placeholder(dtypes.int32, shape=[4]) + strides_placeholder = array_ops.placeholder(dtypes.int32, shape=[4]) + if v2: + t = pool_func( + t, + ksize=ksize_placeholder, + strides=strides_placeholder, + padding=padding, + data_format=data_format) + else: + t = pool_func( + t, + ksize=ksize, + strides=strides, + padding=padding, + data_format=data_format) if data_format == "NCHW": t = test_util.NCHWToNHWC(t) - actual = t.eval() + if v2: + actual = t.eval(feed_dict={ksize_placeholder: ksize, + strides_placeholder: strides}) + else: + actual = t.eval() + self.assertShapeEqual(actual, t) self.assertAllCloseAccordingToType(expected, actual.flatten()) - self.assertShapeEqual(actual, t) def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, - data_format, expected, use_gpu): + data_format, expected, use_gpu, v2): """Verifies the output values of the pooling function. Args: @@ -131,14 +147,14 @@ class PoolingTest(test.TestCase): use_gpu: Whether we are running on GPU. """ self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, - data_format, dtypes.float32, expected, use_gpu) + data_format, dtypes.float32, expected, use_gpu, v2) if not use_gpu or test_util.CudaSupportsHalfMatMulAndConv(): self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, - data_format, dtypes.float16, expected, use_gpu) + data_format, dtypes.float16, expected, use_gpu, v2) def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding, - expected, use_gpu): + expected, use_gpu, v2=False): """Verifies the output values of the pooling function. 
Args: @@ -154,7 +170,7 @@ class PoolingTest(test.TestCase): for (data_format, use_gpu_2) in GetTestConfigs(): if use_gpu_2 == use_gpu: self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding, - data_format, expected, use_gpu) + data_format, expected, use_gpu, v2) def _testAvgPoolValidPadding(self, use_gpu): expected_output = [7.0, 8.0, 9.0] @@ -325,6 +341,17 @@ class PoolingTest(test.TestCase): expected=expected_output, use_gpu=use_gpu) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 3, 3, 3], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="VALID", + expected=expected_output, + use_gpu=use_gpu, + v2=v2) + def _testMaxPoolSamePadding(self, use_gpu): expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0] self._VerifyValues( @@ -336,6 +363,17 @@ class PoolingTest(test.TestCase): expected=expected_output, use_gpu=use_gpu) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 2, 3, 3], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=expected_output, + use_gpu=use_gpu, + v2=v2) + def _testMaxPoolSamePaddingNonSquareWindow(self, use_gpu): # input is: # [1.0, 2.0 @@ -354,6 +392,17 @@ class PoolingTest(test.TestCase): expected=[2.0, 2.0, 4.0, 4.0], use_gpu=use_gpu) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 2, 2, 1], + ksize=[1, 1, 2, 1], + strides=[1, 1, 1, 1], + padding="SAME", + expected=[2.0, 2.0, 4.0, 4.0], + use_gpu=use_gpu, + v2=v2) + def _testMaxPoolValidPaddingUnevenStride(self, use_gpu): self._VerifyValues( nn_ops.max_pool, @@ -372,6 +421,26 @@ class PoolingTest(test.TestCase): expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0], use_gpu=use_gpu) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 4, 4, 1], + ksize=[1, 2, 2, 1], + strides=[1, 1, 2, 1], + padding="VALID", + expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0], + use_gpu=use_gpu, + v2=v2) + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 4, 4, 1], + ksize=[1, 2, 2, 1], + strides=[1, 2, 1, 1], + padding="VALID", + expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0], + use_gpu=use_gpu, + v2=v2) + def _testMaxPoolSamePaddingPacket4(self, use_gpu): expected_output = [ 21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0, @@ -386,6 +455,17 @@ class PoolingTest(test.TestCase): expected=expected_output, use_gpu=use_gpu) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 4, 4, 4], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=expected_output, + use_gpu=use_gpu, + v2=v2) + def _testMaxPoolSamePaddingPacket8(self, use_gpu): expected_output = [ 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0, @@ -411,6 +491,17 @@ class PoolingTest(test.TestCase): expected=expected_output, use_gpu=use_gpu) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 8, 8, 8], + ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=expected_output, + use_gpu=use_gpu, + v2=v2) + def testMaxPooling(self): for use_gpu in True, False: self._testMaxPoolValidPadding(use_gpu) @@ -435,6 +526,17 @@ class PoolingTest(test.TestCase): expected=[2.0, 4.0, 6.0, 8.0, 10.0], use_gpu=False) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 1, 1, 10], + ksize=[1, 1, 1, 2], + strides=[1, 1, 1, 2], + padding="SAME", + expected=[2.0, 4.0, 6.0, 8.0, 10.0], + 
use_gpu=False, + v2=v2) + def testDepthwiseMaxPool2x2DepthWindow3(self): # input is: # @@ -450,6 +552,17 @@ class PoolingTest(test.TestCase): expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0], use_gpu=False) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 2, 2, 6], + ksize=[1, 1, 1, 3], + strides=[1, 1, 1, 3], + padding="SAME", + expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0], + use_gpu=False, + v2=v2) + def testKernelSmallerThanStrideValid(self): for use_gpu in [True, False]: self._VerifyValues( @@ -461,6 +574,17 @@ class PoolingTest(test.TestCase): expected=[9, 12, 30, 33], use_gpu=use_gpu) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 7, 7, 1], + ksize=[1, 2, 2, 1], + strides=[1, 3, 3, 1], + padding="VALID", + expected=[9, 12, 30, 33], + use_gpu=use_gpu, + v2=v2) + self._VerifyValues( nn_ops.avg_pool, input_sizes=[1, 7, 7, 1], @@ -491,6 +615,27 @@ class PoolingTest(test.TestCase): expected=[1, 3, 9, 11], use_gpu=use_gpu) + for v2 in [True, False]: + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 3, 3, 1], + ksize=[1, 1, 1, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=[1, 3, 7, 9], + use_gpu=use_gpu, + v2=v2) + + self._VerifyValues( + gen_nn_ops._max_pool_v2, + input_sizes=[1, 4, 4, 1], + ksize=[1, 1, 1, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=[1, 3, 9, 11], + use_gpu=use_gpu, + v2=v2) + def _testDepthwiseMaxPoolInvalidConfig(self, in_size, ksize, @@ -812,99 +957,107 @@ class PoolingTest(test.TestCase): self.assertLess(err, err_tolerance) def _testMaxPoolGradValidPadding1_1(self, data_format, use_gpu): - self._ConstructAndTestGradient( - nn_ops.max_pool, - input_sizes=[1, 3, 3, 1], - output_sizes=[1, 3, 3, 1], - window_rows=1, - window_cols=1, - row_stride=1, - col_stride=1, - padding="VALID", - data_format=data_format, - use_gpu=use_gpu) + for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + self._ConstructAndTestGradient( + pool_func, + input_sizes=[1, 3, 3, 1], + output_sizes=[1, 3, 3, 1], + window_rows=1, + window_cols=1, + row_stride=1, + col_stride=1, + padding="VALID", + data_format=data_format, + use_gpu=use_gpu) def _testMaxPoolGradValidPadding2_1_6(self, data_format, use_gpu): - self._ConstructAndTestGradient( - nn_ops.max_pool, - input_sizes=[2, 6, 6, 3], - output_sizes=[2, 5, 5, 3], - window_rows=2, - window_cols=2, - row_stride=1, - col_stride=1, - padding="VALID", - data_format=data_format, - use_gpu=use_gpu) + for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + self._ConstructAndTestGradient( + pool_func, + input_sizes=[2, 6, 6, 3], + output_sizes=[2, 5, 5, 3], + window_rows=2, + window_cols=2, + row_stride=1, + col_stride=1, + padding="VALID", + data_format=data_format, + use_gpu=use_gpu) def _testMaxPoolGradValidPadding2_1_7(self, data_format, use_gpu): - self._ConstructAndTestGradient( - nn_ops.max_pool, - input_sizes=[2, 7, 7, 3], - output_sizes=[2, 6, 6, 3], - window_rows=2, - window_cols=2, - row_stride=1, - col_stride=1, - padding="VALID", - data_format=data_format, - use_gpu=use_gpu) + for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + self._ConstructAndTestGradient( + pool_func, + input_sizes=[2, 7, 7, 3], + output_sizes=[2, 6, 6, 3], + window_rows=2, + window_cols=2, + row_stride=1, + col_stride=1, + padding="VALID", + data_format=data_format, + use_gpu=use_gpu) def _testMaxPoolGradValidPadding2_2(self, data_format, use_gpu): - self._ConstructAndTestGradient( - nn_ops.max_pool, - 
input_sizes=[2, 2, 2, 3], - output_sizes=[2, 1, 1, 3], - window_rows=2, - window_cols=2, - row_stride=2, - col_stride=2, - padding="VALID", - data_format=data_format, - use_gpu=use_gpu) + for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + self._ConstructAndTestGradient( + pool_func, + input_sizes=[2, 2, 2, 3], + output_sizes=[2, 1, 1, 3], + window_rows=2, + window_cols=2, + row_stride=2, + col_stride=2, + padding="VALID", + data_format=data_format, + use_gpu=use_gpu) def _testMaxPoolGradSamePadding1_1(self, data_format, use_gpu): - self._ConstructAndTestGradient( - nn_ops.max_pool, - input_sizes=[2, 2, 4, 3], - output_sizes=[2, 2, 4, 3], - window_rows=1, - window_cols=1, - row_stride=1, - col_stride=1, - padding="SAME", - data_format=data_format, - use_gpu=use_gpu) + for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + self._ConstructAndTestGradient( + pool_func, + input_sizes=[2, 2, 4, 3], + output_sizes=[2, 2, 4, 3], + window_rows=1, + window_cols=1, + row_stride=1, + col_stride=1, + padding="SAME", + data_format=data_format, + use_gpu=use_gpu) def _testMaxPoolGradSamePadding2_1(self, data_format, use_gpu): - self._ConstructAndTestGradient( - nn_ops.max_pool, - input_sizes=[2, 2, 4, 3], - output_sizes=[2, 2, 4, 3], - window_rows=2, - window_cols=2, - row_stride=1, - col_stride=1, - padding="SAME", - data_format=data_format, - use_gpu=use_gpu) + for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + self._ConstructAndTestGradient( + pool_func, + input_sizes=[2, 2, 4, 3], + output_sizes=[2, 2, 4, 3], + window_rows=2, + window_cols=2, + row_stride=1, + col_stride=1, + padding="SAME", + data_format=data_format, + use_gpu=use_gpu) def _testMaxPoolGradSamePadding2_2(self, data_format, use_gpu): - self._ConstructAndTestGradient( - nn_ops.max_pool, - input_sizes=[2, 2, 4, 3], - output_sizes=[2, 1, 2, 3], - window_rows=2, - window_cols=2, - row_stride=2, - col_stride=2, - padding="SAME", - data_format=data_format, - use_gpu=use_gpu) + for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + self._ConstructAndTestGradient( + pool_func, + input_sizes=[2, 2, 4, 3], + output_sizes=[2, 1, 2, 3], + window_rows=2, + window_cols=2, + row_stride=2, + col_stride=2, + padding="SAME", + data_format=data_format, + use_gpu=use_gpu) def _testMaxPoolGradSamePadding3_1(self, data_format, use_gpu): - self._ConstructAndTestGradient( - nn_ops.max_pool, + for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + self._ConstructAndTestGradient( + pool_func, input_sizes=[1, 7, 7, 1], output_sizes=[1, 7, 7, 1], window_rows=3, @@ -927,7 +1080,7 @@ class PoolingTest(test.TestCase): self._testMaxPoolGradSamePadding3_1(data_format, use_gpu) def _MaxPoolGrad(self, orig_input, orig_output, grad, window_rows, - window_cols, row_stride, col_stride, padding): + window_cols, row_stride, col_stride, padding, v2): """Max Pooling Gradient. Args: @@ -944,26 +1097,29 @@ class PoolingTest(test.TestCase): Returns: A Tensor. 
""" - return gen_nn_ops._max_pool_grad(orig_input, orig_output, grad, - [1, window_rows, window_cols, 1], - [1, row_stride, col_stride, 1], padding) + pool_func = gen_nn_ops.max_pool_grad_v2 if v2 else gen_nn_ops._max_pool_grad + return pool_func(orig_input, orig_output, grad, + [1, window_rows, window_cols, 1], + [1, row_stride, col_stride, 1], padding) def _testMaxPoolGradDirect(self, input_data, output_backprop, expected_input_backprop, input_sizes, output_sizes, window_rows, window_cols, row_stride, col_stride, - padding, use_gpu): + padding, use_gpu, v2): + pool_func = gen_nn_ops._max_pool_v2 if v2 else nn_ops.max_pool with self.test_session(use_gpu=use_gpu): input_tensor = constant_op.constant(input_data, shape=input_sizes) - output_tensor = nn_ops.max_pool(input_tensor, - [1, window_rows, window_cols, 1], - [1, row_stride, col_stride, 1], padding) + output_tensor = pool_func(input_tensor, + [1, window_rows, window_cols, 1], + [1, row_stride, col_stride, 1], padding) output_backprop_tensor = constant_op.constant( output_backprop, shape=output_sizes) input_backprop_tensor = self._MaxPoolGrad(input_tensor, output_tensor, output_backprop_tensor, window_rows, window_cols, - row_stride, col_stride, padding) + row_stride, col_stride, + padding, v2) actual_input_backprop = input_backprop_tensor.eval() self.assertShapeEqual(actual_input_backprop, input_backprop_tensor) @@ -988,18 +1144,20 @@ class PoolingTest(test.TestCase): ] for use_gpu in True, False: - self._testMaxPoolGradDirect( - input_data, - output_backprop, - expected_input_backprop, - input_sizes=[1, 4, 4, 1], - output_sizes=[1, 3, 3, 1], - window_rows=2, - window_cols=2, - row_stride=1, - col_stride=1, - padding="VALID", - use_gpu=use_gpu) + for v2 in [True, False]: + self._testMaxPoolGradDirect( + input_data, + output_backprop, + expected_input_backprop, + input_sizes=[1, 4, 4, 1], + output_sizes=[1, 3, 3, 1], + window_rows=2, + window_cols=2, + row_stride=1, + col_stride=1, + padding="VALID", + use_gpu=use_gpu, + v2=v2) def _testMaxPoolGradDirect1_2(self): input_data = [ @@ -1013,18 +1171,20 @@ class PoolingTest(test.TestCase): ] for use_gpu in True, False: - self._testMaxPoolGradDirect( - input_data, - output_backprop, - expected_input_backprop, - input_sizes=[1, 4, 4, 1], - output_sizes=[1, 3, 3, 1], - window_rows=2, - window_cols=2, - row_stride=1, - col_stride=1, - padding="VALID", - use_gpu=use_gpu) + for v2 in [True, False]: + self._testMaxPoolGradDirect( + input_data, + output_backprop, + expected_input_backprop, + input_sizes=[1, 4, 4, 1], + output_sizes=[1, 3, 3, 1], + window_rows=2, + window_cols=2, + row_stride=1, + col_stride=1, + padding="VALID", + use_gpu=use_gpu, + v2=v2) def _testMaxPoolGradDirect1_3(self): input_data = [ @@ -1069,18 +1229,20 @@ class PoolingTest(test.TestCase): ] for use_gpu in True, False: - self._testMaxPoolGradDirect( - input_data, - output_backprop, - expected_input_backprop, - input_sizes=[1, 4, 4, 1], - output_sizes=[1, 4, 4, 1], - window_rows=3, - window_cols=3, - row_stride=1, - col_stride=1, - padding="SAME", - use_gpu=use_gpu) + for v2 in [True, False]: + self._testMaxPoolGradDirect( + input_data, + output_backprop, + expected_input_backprop, + input_sizes=[1, 4, 4, 1], + output_sizes=[1, 4, 4, 1], + window_rows=3, + window_cols=3, + row_stride=1, + col_stride=1, + padding="SAME", + use_gpu=use_gpu, + v2=v2) def _testMaxPoolGradDirectWithNans2_1(self): input_data = [float("nan")] * 16 @@ -1090,18 +1252,20 @@ class PoolingTest(test.TestCase): 11.0, 12.0, 13.0, 0.0, 15.0, 16.0, 17.0, 
         0.0, 19.0, 20.0, 21.0, 0.0, 0.0, 0.0, 0.0, 0.0
     ]
-    self._testMaxPoolGradDirect(
-        input_data,
-        output_backprop,
-        expected_input_backprop_tf_cpu,
-        input_sizes=[1, 4, 4, 1],
-        output_sizes=[1, 3, 3, 1],
-        window_rows=2,
-        window_cols=2,
-        row_stride=1,
-        col_stride=1,
-        padding="VALID",
-        use_gpu=False)
+    for v2 in [True, False]:
+      self._testMaxPoolGradDirect(
+          input_data,
+          output_backprop,
+          expected_input_backprop_tf_cpu,
+          input_sizes=[1, 4, 4, 1],
+          output_sizes=[1, 3, 3, 1],
+          window_rows=2,
+          window_cols=2,
+          row_stride=1,
+          col_stride=1,
+          padding="VALID",
+          use_gpu=False,
+          v2=v2)
 
     if not test.is_gpu_available():
       return
@@ -1112,18 +1276,20 @@ class PoolingTest(test.TestCase):
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0
     ]
-    self._testMaxPoolGradDirect(
-        input_data,
-        output_backprop,
-        expected_input_backprop_cudnn,
-        input_sizes=[1, 4, 4, 1],
-        output_sizes=[1, 3, 3, 1],
-        window_rows=2,
-        window_cols=2,
-        row_stride=1,
-        col_stride=1,
-        padding="VALID",
-        use_gpu=True)
+    for v2 in [True, False]:
+      self._testMaxPoolGradDirect(
+          input_data,
+          output_backprop,
+          expected_input_backprop_cudnn,
+          input_sizes=[1, 4, 4, 1],
+          output_sizes=[1, 3, 3, 1],
+          window_rows=2,
+          window_cols=2,
+          row_stride=1,
+          col_stride=1,
+          padding="VALID",
+          use_gpu=True,
+          v2=v2)
 
   def _testMaxPoolGradDirectWithNans2_2(self):
     input_data = [float("nan")] * 16
@@ -1136,18 +1302,20 @@ class PoolingTest(test.TestCase):
         float("nan"), 12.0, 13.0, 0.0, 15.0, float("nan"), 17.0, 0.0, 19.0,
         20.0, float("nan"), 0.0, 0.0, 0.0, 0.0, 0.0
     ]
-    self._testMaxPoolGradDirect(
-        input_data,
-        output_backprop,
-        expected_input_backprop_tf_cpu,
-        input_sizes=[1, 4, 4, 1],
-        output_sizes=[1, 3, 3, 1],
-        window_rows=2,
-        window_cols=2,
-        row_stride=1,
-        col_stride=1,
-        padding="VALID",
-        use_gpu=False)
+    for v2 in [True, False]:
+      self._testMaxPoolGradDirect(
+          input_data,
+          output_backprop,
+          expected_input_backprop_tf_cpu,
+          input_sizes=[1, 4, 4, 1],
+          output_sizes=[1, 3, 3, 1],
+          window_rows=2,
+          window_cols=2,
+          row_stride=1,
+          col_stride=1,
+          padding="VALID",
+          use_gpu=False,
+          v2=v2)
 
     if not test.is_gpu_available():
       return
@@ -1158,18 +1326,20 @@ class PoolingTest(test.TestCase):
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0
     ]
-    self._testMaxPoolGradDirect(
-        input_data,
-        output_backprop,
-        expected_input_backprop_cudnn,
-        input_sizes=[1, 4, 4, 1],
-        output_sizes=[1, 3, 3, 1],
-        window_rows=2,
-        window_cols=2,
-        row_stride=1,
-        col_stride=1,
-        padding="VALID",
-        use_gpu=True)
+    for v2 in [True, False]:
+      self._testMaxPoolGradDirect(
+          input_data,
+          output_backprop,
+          expected_input_backprop_cudnn,
+          input_sizes=[1, 4, 4, 1],
+          output_sizes=[1, 3, 3, 1],
+          window_rows=2,
+          window_cols=2,
+          row_stride=1,
+          col_stride=1,
+          padding="VALID",
+          use_gpu=True,
+          v2=v2)
 
   def testMaxPoolGradDirect(self):
     self._testMaxPoolGradDirect1_1()
@@ -1179,108 +1349,116 @@ class PoolingTest(test.TestCase):
     self._testMaxPoolGradDirectWithNans2_2()
 
   def _testMaxPoolGradGradValidPadding1_1(self, data_format, use_gpu):
-    self._ConstructAndTestSecondGradient(
-        nn_ops.max_pool,
-        input_sizes=[1, 3, 3, 1],
-        output_sizes=[1, 3, 3, 1],
-        window_rows=1,
-        window_cols=1,
-        row_stride=1,
-        col_stride=1,
-        padding="VALID",
-        data_format=data_format,
-        use_gpu=use_gpu)
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestSecondGradient(
+          pool_func,
+          input_sizes=[1, 3, 3, 1],
+          output_sizes=[1, 3, 3, 1],
+          window_rows=1,
+          window_cols=1,
+          row_stride=1,
+          col_stride=1,
+          padding="VALID",
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   def _testMaxPoolGradGradValidPadding2_1_6(self, data_format, use_gpu):
-    self._ConstructAndTestSecondGradient(
-        nn_ops.max_pool,
-        input_sizes=[2, 6, 6, 3],
-        output_sizes=[2, 5, 5, 3],
-        window_rows=2,
-        window_cols=2,
-        row_stride=1,
-        col_stride=1,
-        padding="VALID",
-        data_format=data_format,
-        use_gpu=use_gpu)
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestSecondGradient(
+          pool_func,
+          input_sizes=[2, 6, 6, 3],
+          output_sizes=[2, 5, 5, 3],
+          window_rows=2,
+          window_cols=2,
+          row_stride=1,
+          col_stride=1,
+          padding="VALID",
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   def _testMaxPoolGradGradValidPadding2_1_7(self, data_format, use_gpu):
-    self._ConstructAndTestSecondGradient(
-        nn_ops.max_pool,
-        input_sizes=[2, 7, 7, 3],
-        output_sizes=[2, 6, 6, 3],
-        window_rows=2,
-        window_cols=2,
-        row_stride=1,
-        col_stride=1,
-        padding="VALID",
-        data_format=data_format,
-        use_gpu=use_gpu)
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestSecondGradient(
+          pool_func,
+          input_sizes=[2, 7, 7, 3],
+          output_sizes=[2, 6, 6, 3],
+          window_rows=2,
+          window_cols=2,
+          row_stride=1,
+          col_stride=1,
+          padding="VALID",
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   def _testMaxPoolGradGradValidPadding2_2(self, data_format, use_gpu):
-    self._ConstructAndTestSecondGradient(
-        nn_ops.max_pool,
-        input_sizes=[2, 2, 2, 3],
-        output_sizes=[2, 1, 1, 3],
-        window_rows=2,
-        window_cols=2,
-        row_stride=2,
-        col_stride=2,
-        padding="VALID",
-        data_format=data_format,
-        use_gpu=use_gpu)
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestSecondGradient(
+          pool_func,
+          input_sizes=[2, 2, 2, 3],
+          output_sizes=[2, 1, 1, 3],
+          window_rows=2,
+          window_cols=2,
+          row_stride=2,
+          col_stride=2,
+          padding="VALID",
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   def _testMaxPoolGradGradSamePadding1_1(self, data_format, use_gpu):
-    self._ConstructAndTestSecondGradient(
-        nn_ops.max_pool,
-        input_sizes=[2, 2, 4, 3],
-        output_sizes=[2, 2, 4, 3],
-        window_rows=1,
-        window_cols=1,
-        row_stride=1,
-        col_stride=1,
-        padding="SAME",
-        data_format=data_format,
-        use_gpu=use_gpu)
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestSecondGradient(
+          pool_func,
+          input_sizes=[2, 2, 4, 3],
+          output_sizes=[2, 2, 4, 3],
+          window_rows=1,
+          window_cols=1,
+          row_stride=1,
+          col_stride=1,
+          padding="SAME",
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   def _testMaxPoolGradGradSamePadding2_1(self, data_format, use_gpu):
-    self._ConstructAndTestSecondGradient(
-        nn_ops.max_pool,
-        input_sizes=[2, 2, 4, 3],
-        output_sizes=[2, 2, 4, 3],
-        window_rows=2,
-        window_cols=2,
-        row_stride=1,
-        col_stride=1,
-        padding="SAME",
-        data_format=data_format,
-        use_gpu=use_gpu)
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestSecondGradient(
+          pool_func,
+          input_sizes=[2, 2, 4, 3],
+          output_sizes=[2, 2, 4, 3],
+          window_rows=2,
+          window_cols=2,
+          row_stride=1,
+          col_stride=1,
+          padding="SAME",
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   def _testMaxPoolGradGradSamePadding2_2(self, data_format, use_gpu):
-    self._ConstructAndTestSecondGradient(
-        nn_ops.max_pool,
-        input_sizes=[2, 2, 4, 3],
-        output_sizes=[2, 1, 2, 3],
-        window_rows=2,
-        window_cols=2,
-        row_stride=2,
-        col_stride=2,
-        padding="SAME",
-        data_format=data_format,
-        use_gpu=use_gpu)
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestSecondGradient(
+          pool_func,
+          input_sizes=[2, 2, 4, 3],
+          output_sizes=[2, 1, 2, 3],
+          window_rows=2,
+          window_cols=2,
+          row_stride=2,
+          col_stride=2,
+          padding="SAME",
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   def _testMaxPoolGradGradSamePadding3_1(self, data_format, use_gpu):
-    self._ConstructAndTestSecondGradient(
-        nn_ops.max_pool,
-        input_sizes=[1, 7, 7, 1],
-        output_sizes=[1, 7, 7, 1],
-        window_rows=3,
-        window_cols=3,
-        row_stride=1,
-        col_stride=1,
-        padding="SAME",
-        data_format=data_format,
-        use_gpu=use_gpu)
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestSecondGradient(
+          pool_func,
+          input_sizes=[1, 7, 7, 1],
+          output_sizes=[1, 7, 7, 1],
+          window_rows=3,
+          window_cols=3,
+          row_stride=1,
+          col_stride=1,
+          padding="SAME",
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   def testMaxPoolGradGrad(self):
     for (data_format, use_gpu) in GetTestConfigs():
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index ffa5dc4d623..eeaf418c8b3 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -302,6 +302,7 @@ BiasAddV1
 Relu6
 AvgPool
 MaxPool
+MaxPoolV2
 Softmax
 LogSoftmax
 FractionalAvgPoolGrad
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 094757dac9f..de302a22712 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -541,6 +541,19 @@ def _MaxPoolGrad(op, grad):
                                    data_format=op.get_attr("data_format"))
 
 
+@ops.RegisterGradient("MaxPoolV2")
+def _MaxPoolGradV2(op, grad):
+  ksize = op.inputs[1]
+  strides = op.inputs[2]
+  return gen_nn_ops.max_pool_grad_v2(op.inputs[0],
+                                     op.outputs[0],
+                                     grad,
+                                     ksize,
+                                     strides,
+                                     padding=op.get_attr("padding"),
+                                     data_format=op.get_attr("data_format")), None, None
+
+
 @ops.RegisterGradient("MaxPoolWithArgmax")
 def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
   return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
@@ -567,6 +580,24 @@ def _MaxPoolGradGrad(op, grad):
              data_format=op.get_attr("data_format")))
 
 
+@ops.RegisterGradient("MaxPoolGradV2")
+def _MaxPoolGradGradV2(op, grad):
+  ksize = op.inputs[3]
+  strides = op.inputs[4]
+  return (array_ops.zeros(
+      shape=array_ops.shape(op.inputs[0]),
+      dtype=op.inputs[0].dtype), array_ops.zeros(
+          shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+      gen_nn_ops.max_pool_grad_grad_v2(
+          op.inputs[0],
+          op.inputs[1],
+          grad,
+          ksize,
+          strides,
+          padding=op.get_attr("padding"),
+          data_format=op.get_attr("data_format")), None, None)
+
+
 @ops.RegisterGradient("MaxPoolGradGrad")
 def _MaxPoolGradGradGrad(op, grad):
   return (array_ops.zeros(