From 8b667b7d4bf1b845ce2ccffad9221490afa583ce Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Tue, 30 Aug 2016 13:23:06 -0800 Subject: [PATCH] Adds fractional_max_pool and fractional_avg_pool ops. Fixes #2953. Change: 131754627 --- tensorflow/core/kernels/BUILD | 4 + .../core/kernels/fractional_avg_pool_op.cc | 354 +++++++++++ .../core/kernels/fractional_max_pool_op.cc | 381 ++++++++++++ .../core/kernels/fractional_pool_common.cc | 134 ++++ .../core/kernels/fractional_pool_common.h | 78 +++ tensorflow/core/ops/nn_ops.cc | 201 ++++++ .../shard3/tf.nn.fractional_max_pool.md | 82 +++ .../shard7/tf.nn.fractional_avg_pool.md | 57 ++ tensorflow/g3doc/api_docs/python/index.md | 2 + tensorflow/g3doc/api_docs/python/nn.md | 145 +++++ tensorflow/python/kernel_tests/BUILD | 2 + .../fractional_avg_pool_op_test.py | 521 ++++++++++++++++ .../fractional_max_pool_op_test.py | 582 ++++++++++++++++++ tensorflow/python/ops/hidden_ops.txt | 2 + tensorflow/python/ops/nn.py | 2 + tensorflow/python/ops/nn_grad.py | 47 ++ tensorflow/python/ops/nn_ops.py | 33 + 17 files changed, 2627 insertions(+) create mode 100644 tensorflow/core/kernels/fractional_avg_pool_op.cc create mode 100644 tensorflow/core/kernels/fractional_max_pool_op.cc create mode 100644 tensorflow/core/kernels/fractional_pool_common.cc create mode 100644 tensorflow/core/kernels/fractional_pool_common.h create mode 100644 tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fractional_max_pool.md create mode 100644 tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.fractional_avg_pool.md create mode 100644 tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py create mode 100644 tensorflow/python/kernel_tests/fractional_max_pool_op_test.py diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index aeda266a70c..f279e24af79 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1448,6 +1448,9 @@ tf_kernel_library( srcs = [ "avgpooling_op.cc", "cudnn_pooling_gpu.cc", + "fractional_avg_pool_op.cc", + "fractional_max_pool_op.cc", + "fractional_pool_common.cc", "maxpooling_op.cc", "pooling_ops_3d.cc", "pooling_ops_common.cc", @@ -1455,6 +1458,7 @@ tf_kernel_library( hdrs = [ "avgpooling_op.h", "cudnn_pooling_gpu.h", + "fractional_pool_common.h", "maxpooling_op.h", "pooling_ops_common.h", ], diff --git a/tensorflow/core/kernels/fractional_avg_pool_op.cc b/tensorflow/core/kernels/fractional_avg_pool_op.cc new file mode 100644 index 00000000000..a983d9362cc --- /dev/null +++ b/tensorflow/core/kernels/fractional_avg_pool_op.cc @@ -0,0 +1,354 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#define EIGEN_USE_THREADS + +#include <algorithm> +#include <cmath> +#include <random> +#include <vector> + +#include "tensorflow/core/kernels/fractional_pool_common.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/guarded_philox_random.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; + +template <typename T> +class FractionalAvgPoolOp : public OpKernel { + public: + explicit FractionalAvgPoolOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_)); + OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_)); + OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_)); + OP_REQUIRES(context, pooling_ratio_.size() == 4, + errors::InvalidArgument( + "pooling_ratio field must specify 4 dimensions")); + OP_REQUIRES( + context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1, + errors::Unimplemented("Fractional average pooling is not yet " + "supported on the batch nor channel dimension.")); + OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_)); + pooling_region_generated_ = false; + // Initialize philox random generator. + OP_REQUIRES_OK(context, generator_.Init(context)); + } + + void Compute(OpKernelContext* context) override { + typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + ConstEigenMatrixMap; + typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + EigenMatrixMap; + + constexpr int tensor_in_and_out_dims = 4; + + const Tensor& tensor_in = context->input(0); + OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims, + errors::InvalidArgument("tensor_in must be 4-dimensional")); + + for (int i = 0; i < tensor_in_and_out_dims; ++i) { + input_size_.push_back(tensor_in.dim_size(i)); + } + // Output size. + for (int i = 0; i < tensor_in_and_out_dims; ++i) { + output_size_.push_back( + static_cast<int>(floor(input_size_[i] / pooling_ratio_[i]))); + DCHECK_GT(output_size_[i], 0); + } + + // Generate pooling sequence. + std::vector<int64> row_cum_seq; + std::vector<int64> col_cum_seq; + if (deterministic_) { + if (pooling_region_generated_) { + row_cum_seq = row_cum_seq_; + col_cum_seq = col_cum_seq_; + } else { + row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1], + &generator_, pseudo_random_); + col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2], + &generator_, pseudo_random_); + mutex_lock lock(mu_); + row_cum_seq_ = row_cum_seq; + col_cum_seq_ = col_cum_seq; + pooling_region_generated_ = true; + } + } else { + row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1], + &generator_, pseudo_random_); + col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2], + &generator_, pseudo_random_); + } + + // Prepare output. 
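+    // Besides the pooled values, the op emits the row and col cumulative
+    // pooling sequences as outputs, since FractionalAvgPoolGrad needs them to
+    // reconstruct the randomly generated pooling regions.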
+ Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({output_size_[0], output_size_[1], + output_size_[2], output_size_[3]}), + &output_tensor)); + Tensor* output_row_seq_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 1, TensorShape({static_cast<int64>(row_cum_seq.size())}), + &output_row_seq_tensor)); + Tensor* output_col_seq_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 2, TensorShape({static_cast<int64>(col_cum_seq.size())}), + &output_col_seq_tensor)); + + ConstEigenMatrixMap in_mat( + tensor_in.flat<T>().data(), input_size_[3], + input_size_[2] * input_size_[1] * input_size_[0]); + + EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3], + output_size_[2] * output_size_[1] * output_size_[0]); + // out_count corresponds to number of elements in each pooling cell. + Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols()); + + // Initializes the output tensor and out_count with 0. + out_mat.setZero(); + out_count.setZero(); + + auto output_row_seq_flat = output_row_seq_tensor->flat<int64>(); + auto output_col_seq_flat = output_col_seq_tensor->flat<int64>(); + + // Set output tensors. + for (int i = 0; i < row_cum_seq.size(); ++i) { + output_row_seq_flat(i) = row_cum_seq[i]; + } + + for (int i = 0; i < col_cum_seq.size(); ++i) { + output_col_seq_flat(i) = col_cum_seq[i]; + } + + // For both input and output, + // 0: batch + // 1: row / row + // 2: col / col + // 3: depth / channel + const int64 row_max = input_size_[1] - 1; + const int64 col_max = input_size_[2] - 1; + for (int64 b = 0; b < input_size_[0]; ++b) { + // row sequence. + for (int64 hs = 0; hs < row_cum_seq.size() - 1; ++hs) { + // row start and end. + const int64 row_start = row_cum_seq[hs]; + int64 row_end = + overlapping_ ? row_cum_seq[hs + 1] : row_cum_seq[hs + 1] - 1; + row_end = std::min(row_end, row_max); + + // col sequence. + for (int64 ws = 0; ws < col_cum_seq.size() - 1; ++ws) { + const int64 out_offset = + (b * output_size_[1] + hs) * output_size_[2] + ws; + // col start and end. + const int64 col_start = col_cum_seq[ws]; + int64 col_end = + overlapping_ ? col_cum_seq[ws + 1] : col_cum_seq[ws + 1] - 1; + col_end = std::min(col_end, col_max); + for (int64 h = row_start; h <= row_end; ++h) { + for (int64 w = col_start; w <= col_end; ++w) { + const int64 in_offset = + (b * input_size_[1] + h) * input_size_[2] + w; + out_mat.col(out_offset) += in_mat.col(in_offset); + out_count(out_offset)++; + } + } + } + } + } + DCHECK_GT(out_count.minCoeff(), 0); + out_mat.array().rowwise() /= out_count.transpose().array(); + } + + private: + bool deterministic_; + // meaningful only when deterministic_ is true. 
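+  // When deterministic_ is set, the pooling sequences are generated once on
+  // the first Compute() call and cached in the members below, so repeated
+  // evaluations of this node see identical pooling regions; mu_ guards the
+  // one-time write.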
+  mutex mu_;
+  std::vector<int64> row_cum_seq_;
+  std::vector<int64> col_cum_seq_;
+  bool pooling_region_generated_;
+
+  std::vector<int32> input_size_;
+  std::vector<int32> output_size_;
+  std::vector<float> pooling_ratio_;
+  bool pseudo_random_;
+  bool overlapping_;
+  GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER_FRACTIONALAVGPOOL(type)                                      \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("FractionalAvgPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      FractionalAvgPoolOp<type>)
+
+REGISTER_FRACTIONALAVGPOOL(int32);
+REGISTER_FRACTIONALAVGPOOL(int64);
+REGISTER_FRACTIONALAVGPOOL(float);
+REGISTER_FRACTIONALAVGPOOL(double);
+
+#undef REGISTER_FRACTIONALAVGPOOL
+
+template <class T>
+class FractionalAvgPoolGradOp : public OpKernel {
+ public:
+  explicit FractionalAvgPoolGradOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Here's the basic idea:
+    // The batch and depth dimensions are independent of the row and col
+    // dimensions, and because FractionalAvgPool currently only supports
+    // pooling along row and col, we can think of this 4D tensor
+    // backpropagation as a series of operations on 2D planes.
+    //
+    // For each element of a 'slice' (2D plane) of output_backprop, we need to
+    // figure out its contributors in the FractionalAvgPool operation. This
+    // can be done based on row_pooling_sequence, col_pooling_sequence and
+    // overlapping.
+    // Once we figure out the original contributors, we just need to evenly
+    // divide the value of this element among these contributors.
+    //
+    // Internally, we accumulate the divided out_backprop values in a
+    // temporary tensor of double type, and then cast it back to the
+    // corresponding type T.
+    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        ConstEigenMatrixMap;
+    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        EigenMatrixMap;
+    typedef Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>>
+        EigenDoubleMatrixMap;
+
+    // Grab the inputs.
+    const Tensor& orig_input_tensor_shape = context->input(0);
+    OP_REQUIRES(context, orig_input_tensor_shape.dims() == 1 &&
+                             orig_input_tensor_shape.NumElements() == 4,
+                errors::InvalidArgument("original input tensor shape must be "
+                                        "1-dimensional with 4 elements"));
+    const Tensor& out_backprop = context->input(1);
+    const Tensor& row_seq_tensor = context->input(2);
+    const Tensor& col_seq_tensor = context->input(3);
+
+    const int64 out_batch = out_backprop.dim_size(0);
+    const int64 out_rows = out_backprop.dim_size(1);
+    const int64 out_cols = out_backprop.dim_size(2);
+    const int64 out_depth = out_backprop.dim_size(3);
+
+    auto row_seq_tensor_flat = row_seq_tensor.flat<int64>();
+    auto col_seq_tensor_flat = col_seq_tensor.flat<int64>();
+    auto orig_input_tensor_shape_flat = orig_input_tensor_shape.flat<int64>();
+
+    const int64 in_batch = orig_input_tensor_shape_flat(0);
+    const int64 in_rows = orig_input_tensor_shape_flat(1);
+    const int64 in_cols = orig_input_tensor_shape_flat(2);
+    const int64 in_depth = orig_input_tensor_shape_flat(3);
+
+    constexpr int tensor_in_and_out_dims = 4;
+    // Transform orig_input_tensor_shape into TensorShape.
+    TensorShape in_shape;
+    for (auto i = 0; i < tensor_in_and_out_dims; ++i) {
+      in_shape.AddDim(orig_input_tensor_shape_flat(i));
+    }
+
+    // Create intermediate in_backprop.
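+    // Gradients are accumulated in double precision and only cast back to T
+    // after all contributions have been summed, which avoids accumulated
+    // rounding error (and intermediate truncation for integer T).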
+ Tensor in_backprop_tensor_temp; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<double>::v(), in_shape, + &in_backprop_tensor_temp)); + in_backprop_tensor_temp.flat<double>().setZero(); + // Transform 4D tensor to 2D matrix. + EigenDoubleMatrixMap in_backprop_tensor_temp_mat( + in_backprop_tensor_temp.flat<double>().data(), in_depth, + in_cols * in_rows * in_batch); + ConstEigenMatrixMap out_backprop_mat(out_backprop.flat<T>().data(), + out_depth, + out_cols * out_rows * out_batch); + // Loop through each element of out_backprop and evenly distribute the + // element to the corresponding pooling cell. + const int64 in_max_row_index = in_rows - 1; + const int64 in_max_col_index = in_cols - 1; + for (int64 b = 0; b < out_batch; ++b) { + for (int64 r = 0; r < out_rows; ++r) { + const int64 in_row_start = row_seq_tensor_flat(r); + int64 in_row_end = overlapping_ ? row_seq_tensor_flat(r + 1) + : row_seq_tensor_flat(r + 1) - 1; + in_row_end = std::min(in_row_end, in_max_row_index); + for (int64 c = 0; c < out_cols; ++c) { + const int64 in_col_start = col_seq_tensor_flat(c); + int64 in_col_end = overlapping_ ? col_seq_tensor_flat(c + 1) + : col_seq_tensor_flat(c + 1) - 1; + in_col_end = std::min(in_col_end, in_max_col_index); + + const int64 num_elements_in_pooling_cell = + (in_row_end - in_row_start + 1) * (in_col_end - in_col_start + 1); + const int64 out_index = (b * out_rows + r) * out_cols + c; + // Now we can evenly distribute out_backprop(b, h, w, *) to + // in_backprop(b, hs:he, ws:we, *). + for (int64 in_r = in_row_start; in_r <= in_row_end; ++in_r) { + for (int64 in_c = in_col_start; in_c <= in_col_end; ++in_c) { + const int64 in_index = (b * in_rows + in_r) * in_cols + in_c; + // Walk through each channel (depth). + for (int64 d = 0; d < out_depth; ++d) { + const double out_backprop_element = static_cast<double>( + out_backprop_mat.coeffRef(d, out_index)); + double& in_backprop_ref = + in_backprop_tensor_temp_mat.coeffRef(d, in_index); + in_backprop_ref += + out_backprop_element / num_elements_in_pooling_cell; + } + } + } + } + } + } + + // Depending on the type, cast double to type T. + Tensor* in_backprop_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, in_shape, &in_backprop_tensor)); + auto in_backprop_tensor_flat = in_backprop_tensor->flat<T>(); + auto in_backprop_tensor_temp_flat = in_backprop_tensor_temp.flat<double>(); + for (int64 i = 0; i < in_backprop_tensor_flat.size(); ++i) { + in_backprop_tensor_flat(i) = + static_cast<T>(in_backprop_tensor_temp_flat(i)); + } + } + + private: + bool overlapping_; +}; + +#define REGISTER_FRACTIONALAVGPOOLGRAD(type) \ + REGISTER_KERNEL_BUILDER(Name("FractionalAvgPoolGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T"), \ + FractionalAvgPoolGradOp<type>) + +REGISTER_FRACTIONALAVGPOOLGRAD(int32); +REGISTER_FRACTIONALAVGPOOLGRAD(int64); +REGISTER_FRACTIONALAVGPOOLGRAD(float); +REGISTER_FRACTIONALAVGPOOLGRAD(double); + +#undef REGISTER_FRACTIONALAVGPOOLGRAD +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc new file mode 100644 index 00000000000..482491b504a --- /dev/null +++ b/tensorflow/core/kernels/fractional_max_pool_op.cc @@ -0,0 +1,381 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#define EIGEN_USE_THREADS + +#include <algorithm> +#include <cmath> +#include <random> +#include <vector> + +#include "tensorflow/core/kernels/fractional_pool_common.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/guarded_philox_random.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; + +template <typename T> +class FractionalMaxPoolOp : public OpKernel { + public: + explicit FractionalMaxPoolOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_)); + OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_)); + OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_)); + + OP_REQUIRES(context, pooling_ratio_.size() == 4, + errors::InvalidArgument("pooling_ratio field must " + "specify 4 dimensions")); + + OP_REQUIRES( + context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1, + errors::Unimplemented("Fractional max pooling is not yet " + "supported on the batch nor channel dimension.")); + + OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_)); + pooling_region_generated_ = false; + // Initialize philox random generator. + OP_REQUIRES_OK(context, generator_.Init(context)); + } + + void Compute(OpKernelContext* context) override { + typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + ConstEigenMatrixMap; + typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + EigenMatrixMap; + + constexpr int tensor_in_and_out_dims = 4; + + const Tensor& tensor_in = context->input(0); + OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims, + errors::InvalidArgument("tensor_in must be 4-dimensional")); + + for (int i = 0; i < tensor_in_and_out_dims; ++i) { + input_size_.push_back(tensor_in.dim_size(i)); + } + // Output size. + for (int i = 0; i < tensor_in_and_out_dims; ++i) { + output_size_.push_back( + static_cast<int>(floor(input_size_[i] / pooling_ratio_[i]))); + DCHECK_GT(output_size_[i], 0); + } + + // Generate pooling sequence. 
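+    // GeneratePoolingSequence returns a cumulative boundary sequence of
+    // length output_length + 1. Illustrative example (values assumed, not
+    // fixed): input_length = 10, output_length = 6 may yield the diff
+    // sequence [2, 2, 1, 2, 2, 1], i.e. the cumulative sequence
+    // [0, 2, 4, 5, 7, 9, 10].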
+ std::vector<int64> height_cum_seq; + std::vector<int64> width_cum_seq; + if (deterministic_) { + if (pooling_region_generated_) { + height_cum_seq = height_cum_seq_; + width_cum_seq = width_cum_seq_; + } else { + height_cum_seq = GeneratePoolingSequence( + input_size_[1], output_size_[1], &generator_, pseudo_random_); + width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2], + &generator_, pseudo_random_); + mutex_lock lock(mu_); + height_cum_seq_ = height_cum_seq; + width_cum_seq_ = width_cum_seq; + pooling_region_generated_ = true; + } + } else { + height_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1], + &generator_, pseudo_random_); + width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2], + &generator_, pseudo_random_); + } + + // Prepare output. + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({output_size_[0], output_size_[1], + output_size_[2], output_size_[3]}), + &output_tensor)); + Tensor* output_height_seq_tensor = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output( + 1, TensorShape({static_cast<int64>(height_cum_seq.size())}), + &output_height_seq_tensor)); + Tensor* output_width_seq_tensor = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output( + 2, TensorShape({static_cast<int64>(width_cum_seq.size())}), + &output_width_seq_tensor)); + + ConstEigenMatrixMap in_mat( + tensor_in.flat<T>().data(), input_size_[3], + input_size_[2] * input_size_[1] * input_size_[0]); + + EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3], + output_size_[2] * output_size_[1] * output_size_[0]); + + // Initializes the output tensor with MIN<T>. + output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest()); + + auto output_height_seq_flat = output_height_seq_tensor->flat<int64>(); + auto output_width_seq_flat = output_width_seq_tensor->flat<int64>(); + + // Set output tensors. + for (int i = 0; i < height_cum_seq.size(); ++i) { + output_height_seq_flat(i) = height_cum_seq[i]; + } + + for (int i = 0; i < width_cum_seq.size(); ++i) { + output_width_seq_flat(i) = width_cum_seq[i]; + } + + // For both input and output, + // 0: batch + // 1: height / row + // 2: width / col + // 3: depth / channel + const int64 height_max = input_size_[1] - 1; + const int64 width_max = input_size_[2] - 1; + for (int64 b = 0; b < input_size_[0]; ++b) { + // height sequence. + for (int64 hs = 0; hs < height_cum_seq.size() - 1; ++hs) { + // height start and end. + const int64 height_start = height_cum_seq[hs]; + int64 height_end = + overlapping_ ? height_cum_seq[hs + 1] : height_cum_seq[hs + 1] - 1; + height_end = std::min(height_end, height_max); + + // width sequence. + for (int64 ws = 0; ws < width_cum_seq.size() - 1; ++ws) { + const int64 out_offset = + (b * output_size_[1] + hs) * output_size_[2] + ws; + // width start and end. + const int64 width_start = width_cum_seq[ws]; + int64 width_end = + overlapping_ ? width_cum_seq[ws + 1] : width_cum_seq[ws + 1] - 1; + width_end = std::min(width_end, width_max); + for (int64 h = height_start; h <= height_end; ++h) { + for (int64 w = width_start; w <= width_end; ++w) { + const int64 in_offset = + (b * input_size_[1] + h) * input_size_[2] + w; + out_mat.col(out_offset) = + out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset)); + } + } + } + } + } + } + + private: + bool deterministic_; + // meaningful only when deterministic_ is true. 
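+  // Same one-time caching scheme as FractionalAvgPoolOp: the sequences are
+  // generated once under mu_ when deterministic_ is set, then reused.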
+ mutex mu_; + std::vector<int64> height_cum_seq_; + std::vector<int64> width_cum_seq_; + bool pooling_region_generated_; + + std::vector<int32> input_size_; + std::vector<int32> output_size_; + std::vector<float> pooling_ratio_; + bool pseudo_random_; + bool overlapping_; + GuardedPhiloxRandom generator_; +}; + +#define REGISTER_FRACTIONALMAXPOOL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("FractionalMaxPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + FractionalMaxPoolOp<type>) + +REGISTER_FRACTIONALMAXPOOL(int32); +REGISTER_FRACTIONALMAXPOOL(int64); +REGISTER_FRACTIONALMAXPOOL(float); +REGISTER_FRACTIONALMAXPOOL(double); + +#undef REGISTER_FRACTIONALMAXPOOL + +static const int kInvalidMaxPoolingIndex = -1; + +template <class T> +class FractionalMaxPoolGradOp : public OpKernel { + public: + explicit FractionalMaxPoolGradOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_)); + } + + void Compute(OpKernelContext* context) override { + // There are two steps when calculating gradient for FractionalMaxPool. + // 1) Walk through the process of calculating fractional pooling given + // pooling region; however, in the process, keep track of where the max + // element comes from. (arg_max) + // 2) Populate the value of out_backprop to where arg_max indicates. If + // we support overlapping, it is likely to have multiple out_backprop[i] + // propagates back to the same arg_max value. + typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + ConstEigenMatrixMap; + typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + EigenMatrixMap; + typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>> + EigenIndexMatrixMap; + + const Tensor& tensor_in = context->input(0); + const Tensor& tensor_out = context->input(1); + const Tensor& out_backprop = context->input(2); + const Tensor& height_seq_tensor = context->input(3); + const Tensor& width_seq_tensor = context->input(4); + + // Just to make it similar to FractionalMaxPoolOp. + constexpr int tensor_in_and_out_dims = 4; + std::vector<int64> input_size; + std::vector<int64> output_size; + for (int i = 0; i < tensor_in_and_out_dims; ++i) { + input_size.push_back(tensor_in.dim_size(i)); + } + for (int i = 0; i < tensor_in_and_out_dims; ++i) { + output_size.push_back(tensor_out.dim_size(i)); + } + + // --------- + // Step 1 + // --------- + Tensor tensor_out_dup; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<T>::v(), + tensor_out.shape(), &tensor_out_dup)); + Tensor tensor_out_arg_max; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(), + tensor_out.shape(), + &tensor_out_arg_max)); + // Find arg_max for each tensor_out + ConstEigenMatrixMap tensor_in_mat( + tensor_in.flat<T>().data(), input_size[3], + input_size[2] * input_size[1] * input_size[0]); + EigenMatrixMap tensor_out_dup_mat( + tensor_out_dup.flat<T>().data(), output_size[3], + output_size[2] * output_size[1] * output_size[0]); + EigenIndexMatrixMap tensor_out_arg_max_mat( + tensor_out_arg_max.flat<int64>().data(), output_size[3], + output_size[2] * output_size[1] * output_size[0]); + + tensor_out_arg_max.flat<int64>().setConstant(kInvalidMaxPoolingIndex); + // Initializes the duplicate output tensor with MIN<T>. 
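+    // Step 1 replays the forward pooling into tensor_out_dup while recording,
+    // for each output cell, the flattened input index of its max element in
+    // tensor_out_arg_max; step 2 then scatters out_backprop to those indices.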
+ tensor_out_dup.flat<T>().setConstant(Eigen::NumTraits<T>::lowest()); + + auto height_seq_tensor_flat = height_seq_tensor.flat<int64>(); + auto width_seq_tensor_flat = width_seq_tensor.flat<int64>(); + + // Now walk through the process of fractional max pooling again. + // For both input and output, + // 0: batch + // 1: height / row + // 2: width / col + // 3: depth / channel + const int64 height_max = input_size[1] - 1; + const int64 width_max = input_size[2] - 1; + for (int64 b = 0; b < input_size[0]; ++b) { + // height sequence. + for (int64 hs = 0; hs < height_seq_tensor.dim_size(0) - 1; ++hs) { + // height start and end. + const int64 height_start = height_seq_tensor_flat(hs); + int64 height_end = overlapping_ ? height_seq_tensor_flat(hs + 1) + : height_seq_tensor_flat(hs + 1) - 1; + height_end = std::min(height_end, height_max); + + // width sequence. + for (int64 ws = 0; ws < width_seq_tensor.dim_size(0) - 1; ++ws) { + const int64 out_index = + (b * output_size[1] + hs) * output_size[2] + ws; + // width start and end. + const int64 width_start = width_seq_tensor_flat(ws); + int64 width_end = overlapping_ ? width_seq_tensor_flat(ws + 1) + : width_seq_tensor_flat(ws + 1) - 1; + width_end = std::min(width_end, width_max); + for (int64 h = height_start; h <= height_end; ++h) { + for (int64 w = width_start; w <= width_end; ++w) { + const int64 in_index = + (b * input_size[1] + h) * input_size[2] + w; + // Walk through each channel (depth). + for (int64 d = 0; d < input_size[3]; ++d) { + const T& input_ref = tensor_in_mat.coeffRef(d, in_index); + T& output_ref = tensor_out_dup_mat.coeffRef(d, out_index); + int64& out_arg_max_ref = + tensor_out_arg_max_mat.coeffRef(d, out_index); + if (output_ref < input_ref || + out_arg_max_ref == kInvalidMaxPoolingIndex) { + output_ref = input_ref; + int input_offset = in_index * input_size[3] + d; + out_arg_max_ref = input_offset; + } + } + } + } + } + } + } + + // Check tensor_out_dup is the same as tensor_out. + ConstEigenMatrixMap tensor_out_mat( + tensor_out.flat<T>().data(), output_size[3], + output_size[2] * output_size[1] * output_size[0]); + const int64 num_reshaped_cols = + output_size[2] * output_size[1] * output_size[0]; + for (int64 i = 0; i < num_reshaped_cols; ++i) { + for (int64 j = 0; j < output_size[3]; ++j) { + DCHECK_EQ(tensor_out_dup_mat(j, i), tensor_out_mat(j, i)); + } + } + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, tensor_in.shape(), &output)); + output->flat<T>().setZero(); + + auto out_backprop_flat = out_backprop.flat<T>(); + auto input_backprop_flat = output->flat<T>(); + auto out_arg_max_flat = tensor_out_arg_max.flat<int64>(); + int num_total_outputs = out_backprop_flat.size(); + int num_total_inputs = input_backprop_flat.size(); + + for (int index = 0; index < num_total_outputs; ++index) { + int input_backprop_index = out_arg_max_flat(index); + // According to maxpooling_op.cc, the performance impact below is small. 
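+      // With overlapping pooling, several outputs may share one arg_max
+      // index, so gradients are accumulated with += rather than assigned.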
+ CHECK(input_backprop_index >= 0 && + input_backprop_index < num_total_inputs) + << "Invalid input backprop index: " << input_backprop_index << ", " + << num_total_inputs; + input_backprop_flat(input_backprop_index) += out_backprop_flat(index); + } + } + + private: + bool overlapping_; +}; + +#define REGISTER_FRACTIONALMAXPOOLGRAD(type) \ + REGISTER_KERNEL_BUILDER(Name("FractionalMaxPoolGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T"), \ + FractionalMaxPoolGradOp<type>) + +REGISTER_FRACTIONALMAXPOOLGRAD(int32); +REGISTER_FRACTIONALMAXPOOLGRAD(int64); +REGISTER_FRACTIONALMAXPOOLGRAD(float); +REGISTER_FRACTIONALMAXPOOLGRAD(double); + +#undef REGISTER_FRACTIONALMAXPOOLGRAD +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fractional_pool_common.cc b/tensorflow/core/kernels/fractional_pool_common.cc new file mode 100644 index 00000000000..f3a3dfda8df --- /dev/null +++ b/tensorflow/core/kernels/fractional_pool_common.cc @@ -0,0 +1,134 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/fractional_pool_common.h" + +#include "tensorflow/core/lib/random/simple_philox.h" + +namespace tensorflow { +static std::vector<int64> GeneratePoolingSequencePseudoRandom( + int input_length, int output_length, GuardedPhiloxRandom* generator) { + std::vector<int64> cum_seq(output_length + 1, 0); + std::vector<int64> diff(output_length, 0); + + double alpha = static_cast<double>(input_length) / output_length; + int k = input_length / output_length; + + // In the paper [1], author proposes the following procedure to generate a + // pseudo random pooling region: + // 1) Set a_0 = 1, a_Nout = Nin; + // 2) a_i = ceil(alpha*(u+i)) + // in which, i = 1, 2, ... , Nout-1 + // u is a random number in (0,1) for all i + // alpha = Nin/Nout in (1,2) + // The sequence {a_i} should satisfy a_i-a_{i-1} = 1 or 2 + // Note: for step 1), it makes more sense to make a_Nout = Nin+1, that way, + // a_i-a_{i-1} = 1 or 2 is also true for i = Nout. + // + // However, there are choices of alpha and u that will make + // a_i - a_{i-1} > 2. This happens at the left boundary. For example, with + // alpha = 1.732, u = 0.8, then a_1 = 4, a_1-a_0 = 3. + // This is why u_max1 is needed, i.e. u is a random number in (0,u_max1) + // instead of (0,1). + // Define k = ceil(alpha)-1, then we require: + // a_1 = alpha*(u+1) <= a_0+(k+1) + // ===> This gives u_max1 = (k+2)/alpha - 1. + // + // In addition, when extending the pooling sequence generation process for + // alpha beyond (1,2), e.g. (k,k+1); a check on the right boundary is also + // needed to make sure the last gap a_Nout-a_{Nout-1} >= k. 
Solving it gives
+  //   u_max2 = (Nin+1-k)/alpha - (Nout-1)
+  // Here is an example where u > u_max2, with alpha = 2.3, u = 0.7,
+  // u_max2 = 0.565, Nin = 23, Nout = 10; the sequence from a_0 to a_10 is:
+  //   [1, 4, 7, 9, 11, 14, 16, 18, 21, 23, 24]
+  // The last gap is only 1.
+  //
+  // [1]: https://arxiv.org/abs/1412.6071
+  double u_max1 = (k + 2) / alpha - 1;
+  double u_max2 = (input_length + 1 - k) / alpha - (output_length - 1);
+  double max_u = std::min(u_max1, u_max2);
+
+  // Generate random number in parallel.
+  auto local_gen = generator->ReserveSamples32(2);
+  random::SimplePhilox random(&local_gen);
+  const double u = random.RandDouble() * max_u;
+
+  cum_seq[0] = 1;
+  cum_seq[output_length] = input_length + 1;
+  for (int i = 1; i < output_length; ++i) {
+    cum_seq[i] = static_cast<int>(ceil(alpha * (i + u)));
+  }
+
+  for (int i = 0; i < output_length; ++i) {
+    diff[i] = cum_seq[i + 1] - cum_seq[i];
+  }
+
+  return diff;
+}
+
+static std::vector<int64> GeneratePoolingSequenceRandom(
+    int input_length, int output_length, GuardedPhiloxRandom* generator) {
+  int k = input_length / output_length;
+  int num_random_spot = input_length % output_length;
+  std::vector<int64> diff(output_length, k);
+
+  for (int i = 0; i < num_random_spot; ++i) {
+    diff[i] += 1;
+  }
+
+  // Randomly shuffle this vector.
+  auto local_gen = generator->ReserveSamples32(diff.size());
+  random::SingleSampleAdapter<random::PhiloxRandom> single(&local_gen);
+  const auto uniform = [&single](uint32 n) { return single() % n; };
+  RandomShuffle(diff.begin(), diff.end(), uniform);
+
+  return diff;
+}
+
+std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
+                                           GuardedPhiloxRandom* generator,
+                                           bool pseudo_random) {
+  std::vector<int64> diff;
+  // This is a case that regular pooling can handle; just return diff with
+  // each element equal to input_length/output_length.
+  if (input_length % output_length == 0) {
+    diff = std::vector<int64>(output_length, input_length / output_length);
+  } else if (pseudo_random) {
+    diff = GeneratePoolingSequencePseudoRandom(input_length, output_length,
+                                               generator);
+  } else {
+    diff =
+        GeneratePoolingSequenceRandom(input_length, output_length, generator);
+  }
+
+  // Sanity check.
+  int k = input_length / output_length;
+  for (int i = 0; i < output_length; ++i) {
+    // k <= diff[i] <= k+1.
+    DCHECK_GE(diff[i], k);
+    DCHECK_LE(diff[i], k + 1);
+  }
+
+  // Return cumulative sequence.
+  std::vector<int64> cum_seq(output_length + 1, 0);
+  for (int i = 1; i < cum_seq.size(); ++i) {
+    cum_seq[i] = cum_seq[i - 1] + diff[i - 1];
+  }
+  return cum_seq;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/fractional_pool_common.h b/tensorflow/core/kernels/fractional_pool_common.h
new file mode 100644
index 00000000000..df0bbbfa066
--- /dev/null
+++ b/tensorflow/core/kernels/fractional_pool_common.h
@@ -0,0 +1,78 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
+#define TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+// Shuffle a container randomly, copied from random_shuffle_op.cc
+template <class Iter, class Random>
+static inline void RandomShuffle(Iter first, Iter last, const Random& uniform) {
+  if (first == last) {
+    return;
+  }
+  const auto stop = last - 1;
+  for (auto i = first; i != stop; ++i) {
+    using std::iter_swap;
+    iter_swap(i, i + uniform(last - i));
+  }
+}
+
+// Generate the pooling sequence for fractional pooling along one dimension.
+//
+// Regular max/avg pooling can be viewed as a special case: given, e.g.,
+// * input_length: 10
+// * output_length: 5
+// it will generate the pooling sequence as
+// diff sequence: [2, 2, 2, 2, 2]
+// or as
+// cumulative sequence: [0, 2, 4, 6, 8, 10]
+//
+// In the case of fractional pooling, where input_length is not an integer
+// multiple of output_length, randomness plays a role when generating the
+// pooling sequence. There are two types of randomness (random vs
+// pseudo-random) defined in the paper:
+// http://arxiv.org/abs/1412.6071
+// You can check the paper for the difference between these two types.
+//
+// In summary, the generated diff sequence satisfies the following properties
+// for both types of randomness:
+// * length(generated_diff_pooling_sequence) = output_length
+// * sum(generated_diff_pooling_sequence) = input_length
+// * Let's define floor(input_length / output_length) = K, then
+//   K <= generated_diff_pooling_sequence[i] <= K+1
+// For example, when input_length = 10 and output_length = 6, the following
+// are valid diff pooling sequences:
+// * [1, 2, 2, 1, 2, 2]
+// * [1, 1, 2, 2, 2, 2]
+// [1, 3, 2, 2, 2, 2] is not valid.
+//
+// Args:
+//   input_length: See above explanation
+//   output_length: See above explanation
+//   generator: Parallel version of random number generator
+//   pseudo_random: Whether or not to use pseudo-random generation
+// Returns:
+//   pooling_sequence: This is the cumulative pooling sequence.
+std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
                                           GuardedPhiloxRandom* generator,
                                           bool pseudo_random);
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index affbd269669..fe78541d08f 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1515,4 +1515,205 @@ values: The `k` largest elements along each last dimensional slice.
 indices: The indices of `values` within the last dimension of `input`.
 )doc");
 
+// --------------------------------------------------------------------------
+
+REGISTER_OP("FractionalMaxPool")
+    .Input("value: T")
+    .Output("output: T")
+    .Output("row_pooling_sequence: int64")
+    .Output("col_pooling_sequence: int64")
+    .Attr("pooling_ratio: list(float) >=4")
+    .Attr("pseudo_random: bool = false")
+    .Attr("overlapping: bool = false")
+    .Attr("deterministic: bool = false")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .Attr("T: {float, double, int32, int64}")
+    .Doc(R"doc(
+Performs fractional max pooling on the input.
+
+Fractional max pooling is slightly different than regular max pooling.
In +regular max pooling, you downsize an input set by taking the maximum value of +smaller N x N subsections of the set (often 2x2), and try to reduce the set by +a factor of N, where N is an integer. Fractional max pooling, as you might +expect from the word "fractional", means that the overall reduction ratio N +does not have to be an integer. + +The sizes of the pooling regions are generated randomly but are fairly uniform. +For example, let's look at the height dimension, and the constraints on the +list of rows that will be pool boundaries. + +First we define the following: + +1. input_row_length : the number of rows from the input set +2. output_row_length : which will be smaller than the input +3. alpha = input_row_length / output_row_length : our reduction ratio +4. K = floor(alpha) +5. row_pooling_sequence : this is the result list of pool boundary rows + +Then, row_pooling_sequence should satisfy: + +1. a[0] = 0 : the first value of the sequence is 0 +2. a[end] = input_row_length : the last value of the sequence is the size +3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size +4. length(row_pooling_sequence) = output_row_length+1 + +For more details on fractional max pooling, see this paper: +[Benjamin Graham, Fractional Max-Pooling] +(http://arxiv.org/abs/1412.6071) + +value: 4-D with shape `[batch, height, width, channels]`. +pooling_ratio: Pooling ratio for each dimension of `value`, currently only + supports row and col dimension and should be >= 1.0. For example, a valid + pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements + must be 1.0 because we don't allow pooling on batch and channels + dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions + respectively. +pseudo_random: When set to True, generates the pooling sequence in a + pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin + Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for + difference between pseudorandom and random. +overlapping: When set to True, it means when pooling, the values at the boundary + of adjacent pooling cells are used by both cells. For example: + + `index 0 1 2 3 4` + + `value 20 5 16 3 7` + + If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + The result would be [20, 16] for fractional max pooling. +deterministic: When set to True, a fixed pooling region will be used when + iterating over a FractionalMaxPool node in the computation graph. Mainly used + in unit test to make FractionalMaxPool deterministic. +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. +output: output tensor after fractional max pooling. +row_pooling_sequence: row pooling sequence, needed to calculate gradient. +col_pooling_sequence: column pooling sequence, needed to calculate gradient. +)doc"); + +REGISTER_OP("FractionalMaxPoolGrad") + .Input("orig_input: T") + .Input("orig_output: T") + .Input("out_backprop: T") + .Input("row_pooling_sequence: int64") + .Input("col_pooling_sequence: int64") + .Output("output: T") + .Attr("overlapping: bool = false") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Computes gradient of the FractionalMaxPool function. + +orig_input: Original input for `fractional_max_pool` +orig_output: Original output for `fractional_max_pool` +out_backprop: 4-D with shape `[batch, height, width, channels]`. 
Gradients + w.r.t. the output of `fractional_max_pool`. +row_pooling_sequence: row pooling sequence, form pooling region with + col_pooling_sequence. +col_pooling_sequence: column pooling sequence, form pooling region with + row_pooling sequence. +overlapping: When set to True, it means when pooling, the values at the boundary + of adjacent pooling cells are used by both cells. For example: + + `index 0 1 2 3 4` + + `value 20 5 16 3 7` + + If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + The result would be [20, 16] for fractional max pooling. +output: 4-D. Gradients w.r.t. the input of `fractional_max_pool`. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("FractionalAvgPool") + .Input("value: T") + .Output("output: T") + .Output("row_pooling_sequence: int64") + .Output("col_pooling_sequence: int64") + .Attr("pooling_ratio: list(float) >=4") + .Attr("pseudo_random: bool = false") + .Attr("overlapping: bool = false") + .Attr("deterministic: bool = false") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Performs fractional average pooling on the input. + +Fractional average pooling is similar to Fractional max pooling in the pooling +region generation step. The only difference is that after pooling regions are +generated, a mean operation is performed instead of a max operation in each +pooling region. + +value: 4-D with shape `[batch, height, width, channels]`. +pooling_ratio: Pooling ratio for each dimension of `value`, currently only + supports row and col dimension and should be >= 1.0. For example, a valid + pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements + must be 1.0 because we don't allow pooling on batch and channels + dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions + respectively. +pseudo_random: When set to True, generates the pooling sequence in a + pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin + Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for + difference between pseudorandom and random. +overlapping: When set to True, it means when pooling, the values at the boundary + of adjacent pooling cells are used by both cells. For example: + + `index 0 1 2 3 4` + + `value 20 5 16 3 7` + + If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + The result would be [41/3, 26/3] for fractional avg pooling. +deterministic: When set to True, a fixed pooling region will be used when + iterating over a FractionalAvgPool node in the computation graph. Mainly used + in unit test to make FractionalAvgPool deterministic. +seed: If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: An second seed to avoid seed collision. +output: output tensor after fractional avg pooling. +row_pooling_sequence: row pooling sequence, needed to calculate gradient. +col_pooling_sequence: column pooling sequence, needed to calculate gradient. +)doc"); + +REGISTER_OP("FractionalAvgPoolGrad") + .Input("orig_input_tensor_shape: int64") + .Input("out_backprop: T") + .Input("row_pooling_sequence: int64") + .Input("col_pooling_sequence: int64") + .Output("output: T") + .Attr("overlapping: bool = false") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Computes gradient of the FractionalAvgPool function. 
+ +Unlike FractionalMaxPoolGrad, we don't need to find arg_max for +FractionalAvgPoolGrad, we just need to evenly back-propagate each element of +out_backprop to those indices that form the same pooling cell. Therefore, we +just need to know the shape of original input tensor, instead of the whole +tensor. + +orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool` +out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients + w.r.t. the output of `fractional_avg_pool`. +row_pooling_sequence: row pooling sequence, form pooling region with + col_pooling_sequence. +col_pooling_sequence: column pooling sequence, form pooling region with + row_pooling sequence. +overlapping: When set to True, it means when pooling, the values at the boundary + of adjacent pooling cells are used by both cells. For example: + + `index 0 1 2 3 4` + + `value 20 5 16 3 7` + + If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + The result would be [41/3, 26/3] for fractional avg pooling. +output: 4-D. Gradients w.r.t. the input of `fractional_avg_pool`. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fractional_max_pool.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fractional_max_pool.md new file mode 100644 index 00000000000..ef10897212f --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fractional_max_pool.md @@ -0,0 +1,82 @@ +### `tf.nn.fractional_max_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_max_pool} + +Performs fractional max pooling on the input. + +Fractional max pooling is slightly different than regular max pooling. In +regular max pooling, you downsize an input set by taking the maximum value of +smaller N x N subsections of the set (often 2x2), and try to reduce the set by +a factor of N, where N is an integer. Fractional max pooling, as you might +expect from the word "fractional", means that the overall reduction ratio N +does not have to be an integer. + +The sizes of the pooling regions are generated randomly but are fairly uniform. +For example, let's look at the height dimension, and the constraints on the +list of rows that will be pool boundaries. + +First we define the following: + +1. input_row_length : the number of rows from the input set +2. output_row_length : which will be smaller than the input +3. alpha = input_row_length / output_row_length : our reduction ratio +4. K = floor(alpha) +5. row_pooling_sequence : this is the result list of pool boundary rows + +Then, row_pooling_sequence should satisfy: + +1. a[0] = 0 : the first value of the sequence is 0 +2. a[end] = input_row_length : the last value of the sequence is the size +3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size +4. length(row_pooling_sequence) = output_row_length+1 + +For more details on fractional max pooling, see this paper: +[Benjamin Graham, Fractional Max-Pooling] +(http://arxiv.org/abs/1412.6071) + +##### Args: + + +* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`. + 4-D with shape `[batch, height, width, channels]`. +* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`. + Pooling ratio for each dimension of `value`, currently only + supports row and col dimension and should be >= 1.0. 
For example, a valid + pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements + must be 1.0 because we don't allow pooling on batch and channels + dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions + respectively. +* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`. + When set to True, generates the pooling sequence in a + pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin + Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for + difference between pseudorandom and random. +* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`. + When set to True, it means when pooling, the values at the boundary + of adjacent pooling cells are used by both cells. For example: + + `index 0 1 2 3 4` + + `value 20 5 16 3 7` + + If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + The result would be [20, 16] for fractional max pooling. + +* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`. + When set to True, a fixed pooling region will be used when + iterating over a FractionalMaxPool node in the computation graph. Mainly used + in unit test to make FractionalMaxPool deterministic. +* <b>`seed`</b>: An optional `int`. Defaults to `0`. + If either seed or seed2 are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +* <b>`seed2`</b>: An optional `int`. Defaults to `0`. + An second seed to avoid seed collision. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence). + +* <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional max pooling. +* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient. +* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient. + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.fractional_avg_pool.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.fractional_avg_pool.md new file mode 100644 index 00000000000..595e6649739 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.fractional_avg_pool.md @@ -0,0 +1,57 @@ +### `tf.nn.fractional_avg_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_avg_pool} + +Performs fractional average pooling on the input. + +Fractional average pooling is similar to Fractional max pooling in the pooling +region generation step. The only difference is that after pooling regions are +generated, a mean operation is performed instead of a max operation in each +pooling region. + +##### Args: + + +* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`. + 4-D with shape `[batch, height, width, channels]`. +* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`. + Pooling ratio for each dimension of `value`, currently only + supports row and col dimension and should be >= 1.0. For example, a valid + pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements + must be 1.0 because we don't allow pooling on batch and channels + dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions + respectively. 
+* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, generates the pooling sequence in a pseudorandom
+    fashion, otherwise, in a random fashion. See the paper [Benjamin Graham,
+    Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for the
+    difference between pseudorandom and random.
+* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, the values at the boundary of adjacent pooling cells
+    are used by both cells when pooling. For example:
+
+    `index  0  1  2  3  4`
+
+    `value  20 5  16 3  7`
+
+    If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used
+    twice. The result would be [41/3, 26/3] for fractional avg pooling.
+
+* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, a fixed pooling region will be used when iterating over
+    a FractionalAvgPool node in the computation graph. Mainly used in unit
+    tests to make FractionalAvgPool deterministic.
+* <b>`seed`</b>: An optional `int`. Defaults to `0`.
+    If either seed or seed2 is set to be non-zero, the random number
+    generator is seeded by the given seed. Otherwise, it is seeded by a
+    random seed.
+* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
+    A second seed to avoid seed collision.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
+
+* <b>`output`</b>: A `Tensor`. Has the same type as `value`. Output tensor after fractional avg pooling.
+* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. Row pooling sequence, needed to calculate gradient.
+* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. Column pooling sequence, needed to calculate gradient.
+
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 4daa212c33e..9bb9236ca21 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -473,6 +473,8 @@
   * [`embedding_lookup_sparse`](../../api_docs/python/nn.md#embedding_lookup_sparse)
   * [`erosion2d`](../../api_docs/python/nn.md#erosion2d)
   * [`fixed_unigram_candidate_sampler`](../../api_docs/python/nn.md#fixed_unigram_candidate_sampler)
+  * [`fractional_avg_pool`](../../api_docs/python/nn.md#fractional_avg_pool)
+  * [`fractional_max_pool`](../../api_docs/python/nn.md#fractional_max_pool)
   * [`in_top_k`](../../api_docs/python/nn.md#in_top_k)
   * [`l2_loss`](../../api_docs/python/nn.md#l2_loss)
   * [`l2_normalize`](../../api_docs/python/nn.md#l2_normalize)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 5bf8b5e6c79..b87f176e2bc 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -828,6 +828,151 @@ Performs 3D max pooling on the input.
   A `Tensor`. Has the same type as `input`. The max pooled output tensor.
 
+- - -
+
+### `tf.nn.fractional_avg_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_avg_pool}
+
+Performs fractional average pooling on the input.
+
+Fractional average pooling is similar to fractional max pooling in the
+pooling-region generation step. The only difference is that, after the
+pooling regions are generated, a mean operation is performed instead of a
+max operation in each pooling region.
+
+##### Args:
+
+
+* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+    4-D with shape `[batch, height, width, channels]`.
+* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
+    Pooling ratio for each dimension of `value`. Currently only the row and
+    col dimensions are supported, and the ratios must be >= 1.0. For example,
+    a valid pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and
+    last elements must be 1.0 because we don't allow pooling on the batch and
+    channels dimensions; 1.44 and 1.73 are the pooling ratios on the height
+    and width dimensions, respectively.
+* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, generates the pooling sequence in a pseudorandom
+    fashion, otherwise, in a random fashion. See the paper [Benjamin Graham,
+    Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for the
+    difference between pseudorandom and random.
+* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, the values at the boundary of adjacent pooling cells
+    are used by both cells when pooling. For example:
+
+    `index  0  1  2  3  4`
+
+    `value  20 5  16 3  7`
+
+    If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used
+    twice. The result would be [41/3, 26/3] for fractional avg pooling.
+
+* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, a fixed pooling region will be used when iterating over
+    a FractionalAvgPool node in the computation graph. Mainly used in unit
+    tests to make FractionalAvgPool deterministic.
+* <b>`seed`</b>: An optional `int`. Defaults to `0`.
+    If either seed or seed2 is set to be non-zero, the random number
+    generator is seeded by the given seed. Otherwise, it is seeded by a
+    random seed.
+* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
+    A second seed to avoid seed collision.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
+
+* <b>`output`</b>: A `Tensor`. Has the same type as `value`. Output tensor after fractional avg pooling.
+* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. Row pooling sequence, needed to calculate gradient.
+* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. Column pooling sequence, needed to calculate gradient.
+
+
+- - -
+
+### `tf.nn.fractional_max_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_max_pool}
+
+Performs fractional max pooling on the input.
+
+Fractional max pooling is slightly different from regular max pooling. In
+regular max pooling, you downsize an input set by taking the maximum value of
+smaller N x N subsections of the set (often 2x2), and try to reduce the set
+by a factor of N, where N is an integer. Fractional max pooling, as you might
+expect from the word "fractional", means that the overall reduction ratio N
+does not have to be an integer.
+
+The sizes of the pooling regions are generated randomly but are fairly
+uniform. For example, let's look at the height dimension, and the constraints
+on the list of rows that will be pool boundaries.
+
+First we define the following:
+
+1. input_row_length : the number of rows of the input set
+2. output_row_length : the number of rows of the output, which will be
+   smaller than the input
+3. alpha = input_row_length / output_row_length : our reduction ratio
+4. K = floor(alpha)
+5. row_pooling_sequence : the resulting list of pool boundary rows
+
+Then, row_pooling_sequence should satisfy:
+
+1. a[0] = 0 : the first value of the sequence is 0
+2. a[end] = input_row_length : the last value of the sequence is the input size
+3. K <= (a[i+1] - a[i]) <= K+1 : all intervals have size K or K+1
+4. length(row_pooling_sequence) = output_row_length+1
+
+For more details on fractional max pooling, see the paper:
+[Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071)
+
+##### Args:
+
+
+* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+    4-D with shape `[batch, height, width, channels]`.
+* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
+    Pooling ratio for each dimension of `value`. Currently only the row and
+    col dimensions are supported, and the ratios must be >= 1.0. For example,
+    a valid pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and
+    last elements must be 1.0 because we don't allow pooling on the batch and
+    channels dimensions; 1.44 and 1.73 are the pooling ratios on the height
+    and width dimensions, respectively.
+* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, generates the pooling sequence in a pseudorandom
+    fashion, otherwise, in a random fashion. See the paper [Benjamin Graham,
+    Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for the
+    difference between pseudorandom and random.
+* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, the values at the boundary of adjacent pooling cells
+    are used by both cells when pooling. For example:
+
+    `index  0  1  2  3  4`
+
+    `value  20 5  16 3  7`
+
+    If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used
+    twice. The result would be [20, 16] for fractional max pooling.
+
+* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, a fixed pooling region will be used when iterating over
+    a FractionalMaxPool node in the computation graph. Mainly used in unit
+    tests to make FractionalMaxPool deterministic.
+* <b>`seed`</b>: An optional `int`. Defaults to `0`.
+    If either seed or seed2 is set to be non-zero, the random number
+    generator is seeded by the given seed. Otherwise, it is seeded by a
+    random seed.
+* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
+    A second seed to avoid seed collision.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
+
+* <b>`output`</b>: A `Tensor`. Has the same type as `value`. Output tensor after fractional max pooling.
+* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. Row pooling sequence, needed to calculate gradient.
+* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. Column pooling sequence, needed to calculate gradient.
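+
+As an illustration, here is a minimal usage sketch. The input shape, pooling
+ratio, and seed values below are arbitrary example choices, not requirements
+of the op; `fractional_avg_pool` takes the same arguments:
+
+```python
+import tensorflow as tf
+
+# 4-D input with shape [batch, height, width, channels]; only the height and
+# width dimensions are pooled.
+x = tf.constant([float(i) for i in range(16)], shape=[1, 4, 4, 1])
+# Reduce height and width by a non-integer factor of about 1.44.
+output, row_seq, col_seq = tf.nn.fractional_max_pool(
+    x, pooling_ratio=[1.0, 1.44, 1.44, 1.0], pseudo_random=True,
+    deterministic=True, seed=1, seed2=2)
+with tf.Session() as sess:
+  # With deterministic=True and fixed seeds, the pooling sequences (and hence
+  # the output) are reproducible across runs.
+  print(sess.run([output, row_seq, col_seq]))
+```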
+
+
 ## Morphological filtering

diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index effe925df06..f1e5b042ef7 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -36,6 +36,8 @@ py_tests(
         "determinant_op_test.py",
         "edit_distance_op_test.py",
         "fifo_queue_test.py",
+        "fractional_avg_pool_op_test.py",
+        "fractional_max_pool_op_test.py",
         "identity_op_py_test.py",
         "in_topk_op_test.py",
         "io_ops_test.py",
diff --git a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
new file mode 100644
index 00000000000..dafecf27288
--- /dev/null
+++ b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
@@ -0,0 +1,521 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fractional average pool operation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import gen_nn_ops
+
+
+class FractionalAvgTest(tf.test.TestCase):
+
+  # Random number generator with a fixed seed.
+  _PRNG = np.random.RandomState(341261000)
+  _SEED = 341261001
+  _SEED2 = 341261002
+
+  def _AvgPoolAlongRows(self, input_matrix, row_seq, overlapping):
+    """Perform average pooling along rows of a 2-D matrix based on row_seq.
+
+    Args:
+      input_matrix: A 2-D matrix.
+      row_seq: Cumulative pooling sequence along rows.
+      overlapping: Whether or not to use overlapping when pooling.
+
+    Returns:
+      A 2-D matrix, with
+        * num_rows = len(row_seq)-1
+        * num_cols = input_matrix.num_cols.
+    """
+    output_image = np.zeros(input_matrix.shape[1])
+    row_max = row_seq[-1]
+    for i in range(row_seq.shape[0] - 1):
+      row_start = row_seq[i]
+      row_end = row_seq[i + 1] + 1 if overlapping else row_seq[i + 1]
+      row_end = min(row_end, row_max)
+      output_image = np.vstack((output_image,
+                                np.mean(input_matrix[row_start:row_end, :],
+                                        axis=0)))  # axis 0 is along rows
+    # Remove the sentinel row.
+    return output_image[1:, :]
+
+  def _AvgPoolAlongCols(self, input_matrix, col_seq, overlapping):
+    """Perform average pooling along columns of a 2-D matrix based on col_seq.
+
+    Args:
+      input_matrix: A 2-D matrix.
+      col_seq: Cumulative pooling sequence along columns.
+      overlapping: Whether or not to use overlapping when pooling.
+
+    Returns:
+      A 2-D matrix, with
+        * num_rows = input_matrix.num_rows
+        * num_cols = len(col_seq)-1.
+    """
+    input_matrix = input_matrix.transpose()
+    output_matrix = self._AvgPoolAlongRows(input_matrix, col_seq, overlapping)
+    return output_matrix.transpose()
+
+  def _GetExpectedFractionalAvgPoolResult(self, input_tensor, row_seq, col_seq,
+                                          overlapping):
+    """Get expected fractional average pooling result.
+
+    row_seq and col_seq together define the fractional pooling region.
+
+    Args:
+      input_tensor: Original input tensor, assumed to be a 4-D tensor, with
+        dimensions [batch, height/row, width/column, channels/depth].
+      row_seq: Cumulative pooling sequence along rows.
+      col_seq: Cumulative pooling sequence along columns.
+      overlapping: Use overlapping when doing pooling.
+
+    Returns:
+      A 4-D tensor that is the result of average pooling on input_tensor,
+      based on the pooling region defined by row_seq and col_seq, conditioned
+      on whether or not overlapping is used.
+    """
+    input_shape = input_tensor.shape
+    output_shape = (input_shape[0], len(row_seq) - 1, len(col_seq) - 1,
+                    input_shape[3])
+    output_tensor = np.zeros(shape=output_shape, dtype=input_tensor.dtype)
+    for batch in range(input_shape[0]):
+      for channel in range(input_shape[3]):
+        two_dim_slice = input_tensor[batch, :, :, channel]
+        tmp = self._AvgPoolAlongRows(two_dim_slice, row_seq, overlapping)
+        output_tensor[batch, :, :, channel] = self._AvgPoolAlongCols(
+            tmp, col_seq, overlapping)
+
+    return output_tensor
+
+  def _ValidateFractionalAvgPoolResult(self, input_tensor, pooling_ratio,
+                                       pseudo_random, overlapping):
+    """Validate FractionalAvgPool's result against expected.
+
+    Expected result is computed given input_tensor, and the pooling region
+    defined by row_seq and col_seq.
+
+    Args:
+      input_tensor: A tensor or numpy ndarray.
+      pooling_ratio: A list or tuple of length 4; the first and last elements
+        must be 1.
+      pseudo_random: Use the pseudo-random method to generate the pooling
+        sequence.
+      overlapping: Use overlapping when pooling.
+
+    Returns:
+      None
+    """
+    with self.test_session() as sess:
+      p, r, c = tf.nn.fractional_avg_pool(input_tensor,
+                                          pooling_ratio,
+                                          pseudo_random,
+                                          overlapping,
+                                          deterministic=True,
+                                          seed=self._SEED,
+                                          seed2=self._SEED2)
+      actual, row_seq, col_seq = sess.run([p, r, c])
+      expected = self._GetExpectedFractionalAvgPoolResult(input_tensor,
+                                                          row_seq, col_seq,
+                                                          overlapping)
+      self.assertShapeEqual(expected, p)
+      self.assertAllClose(expected, actual)
+
+  def _testVisually(self):
+    """Manual test that prints intermediate results of a small random tensor.
+
+    Since _GetExpectedFractionalAvgPoolResult is 'automated', it feels safer
+    to have a test case where you can see what's happening.
+    This test will generate a small, random, int 2D matrix, and feed it to
+    FractionalAvgPool and _GetExpectedFractionalAvgPoolResult.
+    """
+    num_rows = 6
+    num_cols = 6
+    tensor_shape = (1, num_rows, num_cols, 1)
+    pseudo_random = False
+    for overlapping in True, False:
+      print("-" * 70)
+      print("Testing FractionalAvgPool with overlapping = {}".format(
+          overlapping))
+      rand_mat = self._PRNG.randint(10, size=tensor_shape)
+      pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
+      with self.test_session() as sess:
+        p, r, c = tf.nn.fractional_avg_pool(
+            rand_mat.astype(np.float32),
+            pooling_ratio,
+            pseudo_random,
+            overlapping,
+            deterministic=True,
+            seed=self._SEED,
+            seed2=self._SEED2)
+        tensor_output, row_seq, col_seq = sess.run([p, r, c])
+        expected_result = self._GetExpectedFractionalAvgPoolResult(
+            rand_mat.astype(np.float32), row_seq, col_seq, overlapping)
+        print("row sequence:")
+        print(row_seq)
+        print("column sequence:")
+        print(col_seq)
+
+        print("Input:")
+        # Print input with pooling region marked.
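+        # A "|" is printed before every column index listed in col_seq, and a
+        # dashed line above every row index listed in row_seq, so the pooling
+        # cell boundaries are visible in the printout.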
+        for i in range(num_rows):
+          row_to_print = []
+          for j in range(num_cols):
+            if j in col_seq:
+              row_to_print.append("|")
+            row_to_print.append(str(rand_mat[0, i, j, 0]))
+          row_to_print.append("|")
+          if i in row_seq:
+            print("-" * 2 * len(row_to_print))
+          print(" ".join(row_to_print))
+        print("-" * 2 * len(row_to_print))
+
+        print("Output from FractionalAvgPool:")
+        print(tensor_output[0, :, :, 0])
+        print("Expected result:")
+        print(expected_result[0, :, :, 0])
+
+  def testAllInputOptions(self):
+    """Try all possible input options for fractional_avg_pool."""
+    num_batches = 5
+    num_channels = 3
+    num_rows = 20
+    num_cols = 30
+    for pseudo_random in True, False:
+      for overlapping in True, False:
+        tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+        # Random tensor with values in [-500.0, 500.0).
+        rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+        self._ValidateFractionalAvgPoolResult(
+            rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
+            overlapping)
+
+  def testIntegerTensorInput(self):
+    """Test that FractionalAvgPool works fine with integer input tensors.
+
+    I would have used the _ValidateFractionalAvgPoolResult function to
+    automate this process, however, there's a rounding issue. It is caused by
+    numpy.mean casting integer input to numpy.float64 for intermediate use,
+    whereas for fractional_avg_pool the mean operation is integer division
+    (truncated). So, for this test case, I will hard code a simple matrix.
+    """
+    pseudo_random = True
+    overlapping = True
+    tensor_shape = (1, 6, 6, 1)
+    # pyformat: disable
+    mat = np.array([
+        [2, 6, 4, 1, 3, 6],
+        [8, 9, 1, 6, 6, 8],
+        [3, 9, 8, 2, 5, 6],
+        [2, 7, 9, 5, 4, 5],
+        [8, 5, 0, 5, 7, 4],
+        [4, 4, 5, 9, 7, 2]
+    ])
+    # pyformat: enable
+    with self.test_session() as sess:
+      # Since deterministic = True, seed and seed2 are fixed. Therefore r and
+      # c are the same each time. We can have an expected result precomputed.
+      # r = [0, 2, 4, 6]
+      # c = [0, 1, 3, 4, 6]
+
+      # pyformat: disable
+      expected = np.array([
+          [6, 5, 3, 5],
+          [5, 5, 4, 5],
+          [5, 4, 7, 5]
+      ]).reshape((1, 3, 4, 1))
+      # pyformat: enable
+      p, unused_r, unused_c = tf.nn.fractional_avg_pool(
+          mat.reshape(tensor_shape), [1, math.sqrt(3), math.sqrt(2), 1],
+          pseudo_random,
+          overlapping,
+          deterministic=True,
+          seed=self._SEED,
+          seed2=self._SEED2)
+      actual = sess.run(p)
+      self.assertShapeEqual(expected, p)
+      self.assertAllClose(expected, actual)
+
+  def testDifferentTensorShapes(self):
+    """Test different shapes of input tensor.
+
+    Mainly test different combinations of num_rows and num_cols.
+    """
+    pseudo_random = True
+    overlapping = True
+    for num_batches in [1, 3]:
+      for num_channels in [1, 3]:
+        for num_rows in [10, 20, 50]:
+          for num_cols in [10, 20, 50]:
+            tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+            # Random tensor with values in [-500.0, 500.0).
+            rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+            self._ValidateFractionalAvgPoolResult(
+                rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
+                overlapping)
+
+  def testLargePoolingRatio(self):
+    """Test when pooling ratio is not within [1, 2).
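+
+    Covers ratios from sqrt(11) up to sqrt(37), so each pooling region spans
+    several rows or columns rather than just one or two.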
+ """ + pseudo_random = True + overlapping = True + num_batches = 3 + num_channels = 3 + num_rows = 30 + num_cols = 50 + tensor_shape = (num_batches, num_rows, num_cols, num_channels) + for row_ratio in [math.sqrt(11), math.sqrt(37)]: + for col_ratio in [math.sqrt(11), math.sqrt(27)]: + # random tensor with value in [-500.0, 500.0) + rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500 + self._ValidateFractionalAvgPoolResult(rand_mat, + [1, row_ratio, col_ratio, 1], + pseudo_random, overlapping) + + def testDivisiblePoolingRatio(self): + """Test when num of rows/cols can evenly divide pooling ratio. + + This is a case regular average pooling can handle. Should be handled by + fractional pooling as well. + """ + pseudo_random = True + overlapping = True + num_batches = 3 + num_channels = 3 + num_rows = 30 + num_cols = 50 + tensor_shape = (num_batches, num_rows, num_cols, num_channels) + # random tensor with value in [-500.0, 500.0) + rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500 + self._ValidateFractionalAvgPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random, + overlapping) + + +class FractionalAvgPoolGradTest(tf.test.TestCase): + """Tests for FractionalAvgPoolGrad. + + Two types of tests for FractionalAvgPoolGrad. + 1) Test fractional_avg_pool_grad() directly. + This type of test relies on gen_nn_ops._avg_pool_grad() returns the + correct result. For example: + * input_tensor_shape = (1, 10, 10, 1) + * window_size = (1, 2, 2, 1) + * stride_size = (1, 2, 2, 1) + * padding: not really important, since 10/2 is divisible + avg pooling should generate the same result as fractional avg pooling with: + * row_sequence = [0, 2, 4, 6, 8, 10] + * col_sequence = [0, 2, 4, 6, 8, 10] + * overlapping = False + This also means their gradients in such case will be the same. 
+ + Similarly, when + * input_tensor_shape = (1, 7, 7, 1) + * window_size = (1, 3, 3, 1) + * stride_size = (1, 2, 2, 1) + * padding: not important + avg pooling should generate the same result as fractional avg pooling with: + * row_sequence = [0, 2, 4, 7] + * col_sequence = [0, 2, 4, 7] + * overlapping = True + 2) Test through compute_gradient_error() + """ + _PRNG = np.random.RandomState(341261004) + _SEED = 341261005 + _SEED2 = 341261006 + + def _GenerateRandomInputTensor(self, shape): + num_elements = 1 + for dim_size in shape: + num_elements *= dim_size + x = self._PRNG.rand(num_elements) * 1000 + return x.reshape(shape) + + def testDirectNotUseOverlapping(self): + for num_batches in [1, 3]: + for row_window_size in [2, 5]: + for col_window_size in [2, 4]: + num_rows = row_window_size * 5 + num_cols = col_window_size * 7 + for num_channels in [1, 2]: + input_shape = (num_batches, num_rows, num_cols, num_channels) + with self.test_session() as _: + input_tensor = tf.constant(self._GenerateRandomInputTensor( + input_shape).astype(np.float32)) + window_size = [1, row_window_size, col_window_size, 1] + stride_size = [1, row_window_size, col_window_size, 1] + padding = "VALID" + output_tensor = tf.nn.avg_pool(input_tensor, window_size, + stride_size, padding) + output_data = output_tensor.eval() + num_elements = 1 + for dim_size in output_data.shape: + num_elements *= dim_size + output_backprop = (self._PRNG.rand(num_elements) * + 1000).reshape(output_data.shape) + input_backprop_tensor = gen_nn_ops._avg_pool_grad( + input_tensor.get_shape(), output_backprop, window_size, + stride_size, padding) + input_backprop = input_backprop_tensor.eval() + row_seq = list(range(0, num_rows + 1, row_window_size)) + col_seq = list(range(0, num_cols + 1, col_window_size)) + fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad( + input_tensor.get_shape(), + output_backprop, + row_seq, + col_seq, + overlapping=False) + fap_input_backprop = fap_input_backprop_tensor.eval() + self.assertShapeEqual(input_backprop, fap_input_backprop_tensor) + self.assertAllClose(input_backprop, fap_input_backprop) + + def testDirectUseOverlapping(self): + for num_batches in [1, 3]: + for row_window_size in [2, 5]: + for col_window_size in [2, 4]: + num_rows = (row_window_size - 1) * 5 + 1 + num_cols = (col_window_size - 1) * 7 + 1 + for num_channels in [1, 2]: + input_shape = (num_batches, num_rows, num_cols, num_channels) + with self.test_session() as _: + input_tensor = tf.constant(self._GenerateRandomInputTensor( + input_shape).astype(np.float32)) + window_size = [1, row_window_size, col_window_size, 1] + stride_size = [1, row_window_size - 1, col_window_size - 1, 1] + padding = "VALID" + output_tensor = tf.nn.avg_pool(input_tensor, window_size, + stride_size, padding) + output_data = output_tensor.eval() + num_elements = 1 + for dim_size in output_data.shape: + num_elements *= dim_size + output_backprop = (self._PRNG.rand(num_elements) * + 1000).reshape(output_data.shape) + input_backprop_tensor = gen_nn_ops._avg_pool_grad( + input_tensor.get_shape(), output_backprop, window_size, + stride_size, padding) + input_backprop = input_backprop_tensor.eval() + row_seq = list(range(0, num_rows, row_window_size - 1)) + col_seq = list(range(0, num_cols, col_window_size - 1)) + row_seq[-1] += 1 + col_seq[-1] += 1 + fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad( + input_tensor.get_shape(), + output_backprop, + row_seq, + col_seq, + overlapping=True) + fap_input_backprop = fap_input_backprop_tensor.eval() 
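+              # The backprop computed through fractional_avg_pool_grad should
+              # match the one from the equivalent avg_pool_grad call above.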
+ self.assertShapeEqual(input_backprop, fap_input_backprop_tensor) + self.assertAllClose(input_backprop, fap_input_backprop) + + def testAllInputOptionsThroughGradientError(self): + input_shape = (1, 7, 13, 1) + input_data = self._GenerateRandomInputTensor(input_shape) + pooling_ratio = [1, math.sqrt(2), math.sqrt(3), 1] + + for pseudo_random in True, False: + for overlapping in True, False: + with self.test_session() as _: + input_tensor = tf.constant(input_data, shape=input_shape) + output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool( + input_tensor, + pooling_ratio, + pseudo_random=pseudo_random, + overlapping=overlapping, + deterministic=True, + seed=self._SEED, + seed2=self._SEED2) + output_data = output_tensor.eval() + output_shape = output_data.shape + # error_margin and delta setting is similar to avg_pool_grad. + error_margin = 1e-4 + gradient_error = tf.test.compute_gradient_error( + input_tensor, + input_shape, + output_tensor, + output_shape, + x_init_value=input_data.reshape(input_shape), + delta=1e-2) + self.assertLess(gradient_error, error_margin) + + def testDifferentTensorShapesThroughGradientError(self): + pseudo_random = True + overlapping = True + pooling_ratio = [1, math.sqrt(3), math.sqrt(2), 1] + for num_batches in [1, 2]: + for num_rows in [5, 13]: + for num_cols in [5, 11]: + for num_channels in [1, 3]: + input_shape = (num_batches, num_rows, num_cols, num_channels) + input_data = self._GenerateRandomInputTensor(input_shape) + with self.test_session() as _: + input_tensor = tf.constant(input_data, shape=input_shape) + output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool( + input_tensor, + pooling_ratio, + pseudo_random=pseudo_random, + overlapping=overlapping, + deterministic=True, + seed=self._SEED, + seed2=self._SEED2) + output_data = output_tensor.eval() + output_shape = output_data.shape + # error_margin and delta setting is similar to avg_pool_grad. + error_margin = 1e-4 + gradient_error = tf.test.compute_gradient_error( + input_tensor, + input_shape, + output_tensor, + output_shape, + x_init_value=input_data.reshape(input_shape), + delta=1e-2) + self.assertLess(gradient_error, error_margin) + + def testLargePoolingRatioThroughGradientError(self): + input_shape = (1, 17, 23, 1) + input_data = self._GenerateRandomInputTensor(input_shape) + pooling_ratio = (1, math.sqrt(13), math.sqrt(7), 1) + output_shape = [int(a / b) for a, b in zip(input_shape, pooling_ratio)] + overlapping = True + pseudo_random = False + + with self.test_session() as _: + input_tensor = tf.constant(input_data, shape=input_shape) + output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool( + input_tensor, + pooling_ratio, + pseudo_random=pseudo_random, + overlapping=overlapping, + deterministic=True, + seed=self._SEED, + seed2=self._SEED2) + # error_margin and delta setting is similar to avg_pool_grad. + error_margin = 1e-4 + gradient_error = tf.test.compute_gradient_error( + input_tensor, + input_shape, + output_tensor, + output_shape, + x_init_value=input_data.reshape(input_shape), + delta=1e-2) + self.assertLess(gradient_error, error_margin) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py new file mode 100644 index 00000000000..424f4d588ab --- /dev/null +++ b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py @@ -0,0 +1,582 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fractional max pool operation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import gen_nn_ops
+
+
+class FractionalMaxPoolTest(tf.test.TestCase):
+
+  # Random number generator with a fixed seed.
+  _PRNG = np.random.RandomState(341261)
+  _SEED = 123456
+  _SEED2 = 654321
+
+  def _MaxPoolAlongRows(self, input_matrix, row_seq, overlapping):
+    """Perform max pooling along rows of a 2-D matrix based on row_seq.
+
+    Args:
+      input_matrix: A 2-D matrix.
+      row_seq: Cumulative pooling sequence along rows.
+      overlapping: Whether or not to use overlapping when pooling.
+
+    Returns:
+      A 2-D matrix, with
+        * num_rows = len(row_seq)-1
+        * num_cols = input_matrix.num_cols.
+    """
+    output_image = np.zeros(input_matrix.shape[1])
+    row_max = row_seq[-1]
+    for i in range(row_seq.shape[0] - 1):
+      row_start = row_seq[i]
+      row_end = row_seq[i + 1] + 1 if overlapping else row_seq[i + 1]
+      row_end = min(row_end, row_max)
+      output_image = np.vstack((output_image,
+                                np.amax(input_matrix[row_start:row_end, :],
+                                        axis=0)))  # axis 0 is along rows
+    # Remove the sentinel row.
+    return output_image[1:, :]
+
+  def _MaxPoolAlongCols(self, input_matrix, col_seq, overlapping):
+    """Perform max pooling along columns of a 2-D matrix based on col_seq.
+
+    Args:
+      input_matrix: A 2-D matrix.
+      col_seq: Cumulative pooling sequence along columns.
+      overlapping: Whether or not to use overlapping when pooling.
+
+    Returns:
+      A 2-D matrix, with
+        * num_rows = input_matrix.num_rows
+        * num_cols = len(col_seq)-1.
+    """
+    input_matrix = input_matrix.transpose()
+    output_matrix = self._MaxPoolAlongRows(input_matrix, col_seq, overlapping)
+    return output_matrix.transpose()
+
+  def _GetExpectedFractionalMaxPoolResult(self, input_tensor, row_seq, col_seq,
+                                          overlapping):
+    """Get expected fractional max pooling result.
+
+    row_seq and col_seq together define the fractional pooling region.
+
+    Args:
+      input_tensor: Original input tensor, assumed to be a 4-D tensor, with
+        dimensions [batch, height/row, width/column, channels/depth].
+      row_seq: Cumulative pooling sequence along rows.
+      col_seq: Cumulative pooling sequence along columns.
+      overlapping: Use overlapping when doing pooling.
+
+    Returns:
+      A 4-D tensor that is the result of max pooling on input_tensor, based
+      on the pooling region defined by row_seq and col_seq, conditioned on
+      whether or not overlapping is used.
+    """
+    input_shape = input_tensor.shape
+    output_shape = (input_shape[0], len(row_seq) - 1, len(col_seq) - 1,
+                    input_shape[3])
+    output_tensor = np.zeros(shape=output_shape, dtype=input_tensor.dtype)
+    for batch in range(input_shape[0]):
+      for channel in range(input_shape[3]):
+        two_dim_slice = input_tensor[batch, :, :, channel]
+        tmp = self._MaxPoolAlongRows(two_dim_slice, row_seq, overlapping)
+        output_tensor[batch, :, :, channel] = self._MaxPoolAlongCols(
+            tmp, col_seq, overlapping)
+
+    return output_tensor
+
+  def _ValidateFractionalMaxPoolResult(self, input_tensor, pooling_ratio,
+                                       pseudo_random, overlapping):
+    """Validate FractionalMaxPool's result against expected.
+
+    Expected result is computed given input_tensor, and the pooling region
+    defined by row_seq and col_seq.
+
+    Args:
+      input_tensor: A tensor or numpy ndarray.
+      pooling_ratio: A list or tuple of length 4; the first and last elements
+        must be 1.
+      pseudo_random: Use the pseudo-random method to generate the pooling
+        sequence.
+      overlapping: Use overlapping when pooling.
+
+    Returns:
+      None
+    """
+    with self.test_session() as sess:
+      p, r, c = tf.nn.fractional_max_pool(input_tensor,
+                                          pooling_ratio,
+                                          pseudo_random,
+                                          overlapping,
+                                          deterministic=True,
+                                          seed=self._SEED,
+                                          seed2=self._SEED2)
+      actual, row_seq, col_seq = sess.run([p, r, c])
+      expected = self._GetExpectedFractionalMaxPoolResult(input_tensor,
+                                                          row_seq, col_seq,
+                                                          overlapping)
+      self.assertShapeEqual(expected, p)
+      self.assertAllClose(expected, actual)
+
+  def _testVisually(self):
+    """Manual test that prints intermediate results of a small random tensor.
+
+    Since _GetExpectedFractionalMaxPoolResult is 'automated', it feels safer
+    to have a test case where you can see what's happening.
+    This test will generate a small, random, int 2D matrix, and feed it to
+    FractionalMaxPool and _GetExpectedFractionalMaxPoolResult.
+    """
+    num_rows = 6
+    num_cols = 6
+    tensor_shape = (1, num_rows, num_cols, 1)
+    pseudo_random = False
+    for overlapping in True, False:
+      print("-" * 70)
+      print("Testing FractionalMaxPool with overlapping = {}".format(
+          overlapping))
+      rand_mat = self._PRNG.randint(10, size=tensor_shape)
+      pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
+      with self.test_session() as sess:
+        p, r, c = tf.nn.fractional_max_pool(rand_mat,
+                                            pooling_ratio,
+                                            pseudo_random,
+                                            overlapping,
+                                            deterministic=True,
+                                            seed=self._SEED,
+                                            seed2=self._SEED2)
+        tensor_output, row_seq, col_seq = sess.run([p, r, c])
+        expected_result = self._GetExpectedFractionalMaxPoolResult(rand_mat,
+                                                                   row_seq,
+                                                                   col_seq,
+                                                                   overlapping)
+        print("row sequence:")
+        print(row_seq)
+        print("column sequence:")
+        print(col_seq)
+
+        print("Input:")
+        # Print input with pooling region marked.
+        for i in range(num_rows):
+          row_to_print = []
+          for j in range(num_cols):
+            if j in col_seq:
+              row_to_print.append("|")
+            row_to_print.append(str(rand_mat[0, i, j, 0]))
+          row_to_print.append("|")
+          if i in row_seq:
+            print("-" * 2 * len(row_to_print))
+          print(" ".join(row_to_print))
+        print("-" * 2 * len(row_to_print))
+
+        print("Output from FractionalMaxPool:")
+        print(tensor_output[0, :, :, 0])
+        print("Expected result:")
+        print(expected_result[0, :, :, 0])
+
+  def testAllInputOptions(self):
+    """Try all possible input options for fractional_max_pool.
+ """ + num_batches = 5 + num_channels = 3 + num_rows = 20 + num_cols = 30 + for pseudo_random in True, False: + for overlapping in True, False: + tensor_shape = (num_batches, num_rows, num_cols, num_channels) + # random tensor with value in [-500.0, 500.0) + rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500 + self._ValidateFractionalMaxPoolResult( + rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random, + overlapping) + + def testIntegerTensorInput(self): + """Test it works fine when input tensor is integer type. + """ + num_batches = 5 + num_channels = 3 + num_rows = 20 + num_cols = 30 + pseudo_random = True + overlapping = True + tensor_shape = (num_batches, num_rows, num_cols, num_channels) + rand_mat = self._PRNG.randint(1000, size=tensor_shape) + self._ValidateFractionalMaxPoolResult(rand_mat, + [1, math.sqrt(3), math.sqrt(2), 1], + pseudo_random, overlapping) + + def testDifferentTensorShapes(self): + """Test different shapes of input tensor. + + Mainly test different combinations of num_rows and num_cols. + """ + pseudo_random = True + overlapping = True + for num_batches in [1, 3]: + for num_channels in [1, 3]: + for num_rows in [10, 20, 50]: + for num_cols in [10, 20, 50]: + tensor_shape = (num_batches, num_rows, num_cols, num_channels) + # random tensor with value in [-500.0, 500.0) + rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500 + self._ValidateFractionalMaxPoolResult( + rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random, + overlapping) + + def testLargePoolingRatio(self): + """Test when pooling ratio is not within [1, 2). + """ + pseudo_random = True + overlapping = True + num_batches = 3 + num_channels = 3 + num_rows = 30 + num_cols = 50 + tensor_shape = (num_batches, num_rows, num_cols, num_channels) + for row_ratio in [math.sqrt(11), math.sqrt(37)]: + for col_ratio in [math.sqrt(11), math.sqrt(27)]: + # random tensor with value in [-500.0, 500.0) + rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500 + self._ValidateFractionalMaxPoolResult(rand_mat, + [1, row_ratio, col_ratio, 1], + pseudo_random, overlapping) + + def testDivisiblePoolingRatio(self): + """Test when num of rows/cols can evenly divide pooling ratio. + + This is a case regular max pooling can handle. Should be handled by + fractional pooling as well. + """ + pseudo_random = True + overlapping = True + num_batches = 3 + num_channels = 3 + num_rows = 30 + num_cols = 50 + tensor_shape = (num_batches, num_rows, num_cols, num_channels) + # random tensor with value in [-500.0, 500.0) + rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500 + self._ValidateFractionalMaxPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random, + overlapping) + + +class FractionalMaxPoolGradTest(tf.test.TestCase): + """Tests for FractionalMaxPoolGrad. + + Two types of tests for FractionalMaxPoolGrad. + 1) Test fractional_max_pool_grad() directly. + This type of test relies on gen_nn_ops._max_pool_grad() returns the correct + result. For example: + * input_tensor_shape = (1, 10, 10, 1) + * window_size = (1, 2, 2, 1) + * stride_size = (1, 2, 2, 1) + * padding: not really import, since 10/2 is divisible + max pooling should generate the same result as fractional max pooling with: + * row_sequence = [0, 2, 4, 6, 8, 10] + * col_sequence = [0, 2, 4, 6, 8, 10] + * overlapping = False + This also means their gradients in such case will be the same. 
+ + Similarly, when + * input_tensor_shape = (1, 7, 7, 1) + * window_size = (1, 3, 3, 1) + * stride_size = (1, 2, 2, 1) + * padding: not important + max pooling should generate the same result as fractional max pooling with: + * row_sequence = [0, 2, 4, 7] + * col_sequence = [0, 2, 4, 7] + * overlapping = True + 2) Test through compute_gradient_error() + """ + + _PRNG = np.random.RandomState(341261) + _SEED = 123456 + _SEED2 = 654321 + + def _GenerateUniqueRandomInputTensor(self, shape): + """Generate 'unqiue' random input tensor. + + 'Unique' means there's no collision values in the tensor, all elements are + different. This is done by generating sequence of integers with step of 1 + and then randomly shuffle these integers. + + Args: + shape: Shape of the tensor desired. + + Returns: + A numpy ndarray with size = shape and dtype = numpy.float32. + """ + num_elements = 1 + for size in shape: + num_elements *= size + x = np.arange(num_elements, dtype=np.float32) + self._PRNG.shuffle(x) + return x.reshape(shape) + + def testDirectNotUseOverlapping(self): + for num_batches in [1, 3]: + for row_window_size in [2, 5]: + for col_window_size in [2, 4]: + num_rows = row_window_size * 5 + num_cols = col_window_size * 7 + for num_channels in [1, 2]: + input_shape = (num_batches, num_rows, num_cols, num_channels) + with self.test_session() as _: + input_tensor = tf.constant(self._GenerateUniqueRandomInputTensor( + input_shape)) + window_size = [1, row_window_size, col_window_size, 1] + stride_size = [1, row_window_size, col_window_size, 1] + padding = "VALID" + output_tensor = tf.nn.max_pool(input_tensor, window_size, + stride_size, padding) + output_data = output_tensor.eval() + output_backprop = self._PRNG.randint(100, size=output_data.shape) + input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor, + output_tensor, + output_backprop, + window_size, + stride_size, + padding) + input_backprop = input_backprop_tensor.eval() + row_seq = list(range(0, num_rows + 1, row_window_size)) + col_seq = list(range(0, num_cols + 1, col_window_size)) + fmp_input_backprop_tensor = gen_nn_ops._fractional_max_pool_grad( + input_tensor, + output_tensor, + output_backprop, + row_seq, + col_seq, + overlapping=False) + fmp_input_backprop = fmp_input_backprop_tensor.eval() + self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor) + self.assertAllClose(input_backprop, fmp_input_backprop) + + def testDirectUseOverlapping(self): + for num_batches in [1, 3]: + for row_window_size in [2, 5]: + for col_window_size in [2, 4]: + num_rows = (row_window_size - 1) * 5 + 1 + num_cols = (col_window_size - 1) * 7 + 1 + for num_channels in [1, 2]: + input_shape = (num_batches, num_rows, num_cols, num_channels) + with self.test_session() as _: + input_tensor = tf.constant(self._GenerateUniqueRandomInputTensor( + input_shape)) + window_size = [1, row_window_size, col_window_size, 1] + stride_size = [1, row_window_size - 1, col_window_size - 1, 1] + padding = "VALID" + output_tensor = tf.nn.max_pool(input_tensor, window_size, + stride_size, padding) + output_data = output_tensor.eval() + output_backprop = self._PRNG.randint(100, size=output_data.shape) + input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor, + output_tensor, + output_backprop, + window_size, + stride_size, + padding) + input_backprop = input_backprop_tensor.eval() + row_seq = list(range(0, num_rows, row_window_size - 1)) + col_seq = list(range(0, num_cols, col_window_size - 1)) + row_seq[-1] += 1 + col_seq[-1] += 1 + fmp_input_backprop_tensor = 
gen_nn_ops._fractional_max_pool_grad( + input_tensor, + output_tensor, + output_backprop, + row_seq, + col_seq, + overlapping=True) + fmp_input_backprop = fmp_input_backprop_tensor.eval() + self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor) + self.assertAllClose(input_backprop, fmp_input_backprop) + + def testAllInputOptionsThroughGradientError(self): + input_shape = (1, 7, 13, 1) + input_data = self._GenerateUniqueRandomInputTensor(input_shape) + # Add some randomness to make input_data not so 'integer' + input_data += self._PRNG.random_sample(input_shape) + pooling_ratio = [1, math.sqrt(2), math.sqrt(3), 1] + + for pseudo_random in True, False: + for overlapping in True, False: + with self.test_session() as _: + input_tensor = tf.constant(input_data, shape=input_shape) + output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool( + input_tensor, + pooling_ratio, + pseudo_random=pseudo_random, + overlapping=overlapping, + deterministic=True, + seed=self._SEED, + seed2=self._SEED2) + output_data = output_tensor.eval() + output_shape = output_data.shape + # error_margin and delta setting is similar to max_pool_grad. + error_margin = 1e-3 + gradient_error = tf.test.compute_gradient_error( + input_tensor, + input_shape, + output_tensor, + output_shape, + x_init_value=input_data.reshape(input_shape), + delta=1e-2) + self.assertLess(gradient_error, error_margin) + + def testDifferentTensorShapesThroughGradientError(self): + pseudo_random = True + overlapping = True + pooling_ratio = [1, math.sqrt(3), math.sqrt(2), 1] + for num_batches in [1, 2]: + for num_rows in [5, 13]: + for num_cols in [5, 11]: + for num_channels in [1, 3]: + input_shape = (num_batches, num_rows, num_cols, num_channels) + input_data = self._GenerateUniqueRandomInputTensor(input_shape) + # Add some randomness to make input_data not so 'integer' + input_data += self._PRNG.random_sample(input_shape) + with self.test_session() as _: + input_tensor = tf.constant(input_data, shape=input_shape) + output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool( + input_tensor, + pooling_ratio, + pseudo_random=pseudo_random, + overlapping=overlapping, + deterministic=True, + seed=self._SEED, + seed2=self._SEED2) + output_data = output_tensor.eval() + output_shape = output_data.shape + # error_margin and delta setting is similar to max_pool_grad. + error_margin = 1e-3 + gradient_error = tf.test.compute_gradient_error( + input_tensor, + input_shape, + output_tensor, + output_shape, + x_init_value=input_data.reshape(input_shape), + delta=1e-2) + self.assertLess(gradient_error, error_margin) + + def testLargePoolingRatioThroughGradientError(self): + input_shape = (1, 17, 23, 1) + input_data = self._GenerateUniqueRandomInputTensor(input_shape) + # Add some randomness to make input_data not so 'integer' + input_data += self._PRNG.random_sample(input_shape) + pooling_ratio = (1, math.sqrt(13), math.sqrt(7), 1) + output_shape = [int(a / b) for a, b in zip(input_shape, pooling_ratio)] + overlapping = True + pseudo_random = False + + with self.test_session() as _: + input_tensor = tf.constant(input_data, shape=input_shape) + output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool( + input_tensor, + pooling_ratio, + pseudo_random=pseudo_random, + overlapping=overlapping, + deterministic=True, + seed=self._SEED, + seed2=self._SEED2) + # error_margin and delta setting is similar to max_pool_grad. 
+      error_margin = 1e-3
+      gradient_error = tf.test.compute_gradient_error(
+          input_tensor,
+          input_shape,
+          output_tensor,
+          output_shape,
+          x_init_value=input_data.reshape(input_shape),
+          delta=1e-2)
+      self.assertLess(gradient_error, error_margin)
+
+  def testWhenRepeatedMaxValueInPoolingRegion(self):
+    """Test when there are repeated values in a pooling region.
+
+    There's no formal definition for what the gradient should be when there
+    are multiple max values within a pooling cell, such as
+      | 1 5 |
+      | 5 3 |
+    The expected result depends heavily on the implementation; if someone
+    swaps the order of a nested for loop when walking through the tensor, the
+    result would be very different.
+
+    The goal of this test is to alert when someone else changes the
+    implementation. The current implementation scans row-by-row.
+    """
+    input_data = [5.0, 4.0, 6.0, 7.0,
+                  3.0, 5.0, 9.0, 6.0,
+                  8.0, 8.0, 9.0, 5.0,
+                  7.0, 4.0, 0.0, 0.0]  # pyformat: disable
+    input_size = [1, 4, 4, 1]
+    output_backprop = [12.0, 15.0,
+                       17.0, -5.0,
+                       6.0, 21.0]  # pyformat: disable
+    row_seq = [0, 1, 3, 4]
+    col_seq = [0, 2, 4]
+    output_data_not_overlapping = [5.0, 7.0,
+                                   8.0, 9.0,
+                                   7.0, 0.0]  # pyformat: disable
+    output_data_overlapping = [9.0, 9.0,
+                               9.0, 9.0,
+                               7.0, 0.0]  # pyformat: disable
+    output_size = [1, 3, 2, 1]
+    expected_input_backprop_not_overlapping = np.reshape(
+        [12.0, 0.0, 0.0, 15.0,
+         0.0, 0.0, -5.0, 0.0,
+         17.0, 0.0, 0.0, 0.0,
+         6.0, 0.0, 21.0, 0.0],
+        input_size)  # pyformat: disable
+    expected_input_backprop_overlapping = np.reshape(
+        [0.0, 0.0, 0.0, 0.0,
+         0.0, 0.0, 39.0, 0.0,
+         0.0, 0.0, 0.0, 0.0,
+         6.0, 0.0, 21.0, 0.0],
+        input_size)  # pyformat: disable
+    with self.test_session() as _:
+      # Test when overlapping is False.
+      input_tensor = tf.constant(input_data, shape=input_size)
+      output_tensor = tf.constant(output_data_not_overlapping,
+                                  shape=output_size)
+      grad = tf.constant(output_backprop, shape=output_size)
+      r = gen_nn_ops._fractional_max_pool_grad(
+          input_tensor,
+          output_tensor,
+          grad,
+          row_seq,
+          col_seq,
+          overlapping=False)
+      input_backprop_not_overlapping = r.eval()
+      self.assertShapeEqual(
+          np.reshape(expected_input_backprop_not_overlapping, input_size), r)
+      self.assertAllClose(expected_input_backprop_not_overlapping,
+                          input_backprop_not_overlapping)
+      # Test when overlapping is True.
+      output_tensor = tf.constant(output_data_overlapping, shape=output_size)
+      r = gen_nn_ops._fractional_max_pool_grad(
+          input_tensor, output_tensor, grad, row_seq, col_seq,
+          overlapping=True)
+      input_backprop_overlapping = r.eval()
+      self.assertShapeEqual(
+          np.reshape(expected_input_backprop_overlapping, input_size), r)
+      self.assertAllClose(expected_input_backprop_overlapping,
+                          input_backprop_overlapping)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 3a94b9f4e39..e243720cabf 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -181,6 +181,8 @@ AvgPool
 MaxPool
 Softmax
 LogSoftmax
+FractionalAvgPoolGrad
+FractionalMaxPoolGrad
 
 # parsing_ops
 ParseExample
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 941bbfd271b..3c3c0cca41a 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -132,6 +132,8 @@ to the `Convolution` section for details about the padding calculation.
 @@max_pool_with_argmax
 @@avg_pool3d
 @@max_pool3d
+@@fractional_avg_pool
+@@fractional_max_pool
 
 ## Morphological filtering
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index a396f06ebad..561ba6d5bb7 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -361,6 +361,53 @@ def _MaxPoolGrad(op, grad):
                                        data_format=op.get_attr("data_format"))
 
 
+@ops.RegisterGradient("FractionalMaxPool")
+def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
+  """Returns gradient for FractionalMaxPool.
+
+  Since FractionalMaxPool has three outputs, three gradients are passed in,
+  one for each output. Only the first one is useful; the other two are empty.
+
+  Args:
+    op: The FractionalMaxPoolOp.
+    grad_0: Gradient with respect to op.outputs[0].
+    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
+    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.
+
+  Returns:
+    Input backprop for the FractionalMaxPool op.
+  """
+  # pylint: disable=protected-access
+  return gen_nn_ops._fractional_max_pool_grad(op.inputs[0], op.outputs[0],
+                                              grad_0, op.outputs[1],
+                                              op.outputs[2],
+                                              op.get_attr("overlapping"))
+
+
+@ops.RegisterGradient("FractionalAvgPool")
+def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
+  """Returns gradient for FractionalAvgPool.
+
+  Since FractionalAvgPool has three outputs, three gradients are passed in,
+  one for each output. Only the first one is useful; the other two are empty.
+
+  Args:
+    op: The FractionalAvgPoolOp.
+    grad_0: Gradient with respect to op.outputs[0].
+    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
+    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.
+
+  Returns:
+    Input backprop for the FractionalAvgPool op.
+  """
+  # pylint: disable=protected-access
+  return gen_nn_ops._fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
+                                              op.outputs[1], op.outputs[2],
+                                              op.get_attr("overlapping"))
+
+
 @ops.RegisterGradient("BatchNormWithGlobalNormalization")
 def _BatchNormWithGlobalNormalizationGrad(op, grad):
   """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index f3805c3f2d7..f5f5aedc01a 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -941,6 +941,39 @@ def _AvgPoolGradShape(op):
   return [tensor_shape.unknown_shape(ndims=4)]
 
 
+@ops.RegisterShape("FractionalMaxPool")
+@ops.RegisterShape("FractionalAvgPool")
+def _fractional_pool_shape(op):
+  input_dims = op.inputs[0].get_shape().with_rank(4).as_list()
+  pooling_ratio = op.get_attr("pooling_ratio")
+  output_dims = np.divide(input_dims, pooling_ratio).astype(int)
+  return [
+      # output.
+      tensor_shape.TensorShape(output_dims),
+      # row_pooling_sequence.
+      tensor_shape.TensorShape([output_dims[1]]),
+      # col_pooling_sequence.
+ tensor_shape.TensorShape([output_dims[2]]) + ] + + +@ops.RegisterShape("FractionalMaxPoolGrad") +def _fractional_max_pool_grad_shape(op): + """Shape function for the FractionalMaxPoolGrad op.""" + orig_input_shape = op.inputs[0].get_shape().with_rank(4) + return [orig_input_shape] + + +@ops.RegisterShape("FractionalAvgPoolGrad") +def _fractional_avg_pool_grad_shape(op): + """Shape function for the FractionalAvgPoolGrad op.""" + orig_input_shape = tensor_util.constant_value(op.inputs[0]) + if orig_input_shape is not None: + return [tensor_shape.TensorShape(orig_input_shape.tolist())] + else: + return [tensor_shape.unknown_shape(ndims=4)] + + @ops.RegisterShape("Conv2DBackpropFilter") def _Conv2DBackpropFilterShape(op): """Shape function for the Conv2DBackpropFilter op."""