From 8b667b7d4bf1b845ce2ccffad9221490afa583ce Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 30 Aug 2016 13:23:06 -0800
Subject: [PATCH] Adds fractional_max_pool and fractional_avg_pool ops. Fixes
 #2953. Change: 131754627

---
 tensorflow/core/kernels/BUILD                 |   4 +
 .../core/kernels/fractional_avg_pool_op.cc    | 354 +++++++++++
 .../core/kernels/fractional_max_pool_op.cc    | 381 ++++++++++++
 .../core/kernels/fractional_pool_common.cc    | 134 ++++
 .../core/kernels/fractional_pool_common.h     |  78 +++
 tensorflow/core/ops/nn_ops.cc                 | 201 ++++++
 .../shard3/tf.nn.fractional_max_pool.md       |  82 +++
 .../shard7/tf.nn.fractional_avg_pool.md       |  57 ++
 tensorflow/g3doc/api_docs/python/index.md     |   2 +
 tensorflow/g3doc/api_docs/python/nn.md        | 145 +++++
 tensorflow/python/kernel_tests/BUILD          |   2 +
 .../fractional_avg_pool_op_test.py            | 521 ++++++++++++++++
 .../fractional_max_pool_op_test.py            | 582 ++++++++++++++++++
 tensorflow/python/ops/hidden_ops.txt          |   2 +
 tensorflow/python/ops/nn.py                   |   2 +
 tensorflow/python/ops/nn_grad.py              |  47 ++
 tensorflow/python/ops/nn_ops.py               |  33 +
 17 files changed, 2627 insertions(+)
 create mode 100644 tensorflow/core/kernels/fractional_avg_pool_op.cc
 create mode 100644 tensorflow/core/kernels/fractional_max_pool_op.cc
 create mode 100644 tensorflow/core/kernels/fractional_pool_common.cc
 create mode 100644 tensorflow/core/kernels/fractional_pool_common.h
 create mode 100644 tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fractional_max_pool.md
 create mode 100644 tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.fractional_avg_pool.md
 create mode 100644 tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
 create mode 100644 tensorflow/python/kernel_tests/fractional_max_pool_op_test.py

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index aeda266a70c..f279e24af79 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1448,6 +1448,9 @@ tf_kernel_library(
     srcs = [
         "avgpooling_op.cc",
         "cudnn_pooling_gpu.cc",
+        "fractional_avg_pool_op.cc",
+        "fractional_max_pool_op.cc",
+        "fractional_pool_common.cc",
         "maxpooling_op.cc",
         "pooling_ops_3d.cc",
         "pooling_ops_common.cc",
@@ -1455,6 +1458,7 @@ tf_kernel_library(
     hdrs = [
         "avgpooling_op.h",
         "cudnn_pooling_gpu.h",
+        "fractional_pool_common.h",
         "maxpooling_op.h",
         "pooling_ops_common.h",
     ],
diff --git a/tensorflow/core/kernels/fractional_avg_pool_op.cc b/tensorflow/core/kernels/fractional_avg_pool_op.cc
new file mode 100644
index 00000000000..a983d9362cc
--- /dev/null
+++ b/tensorflow/core/kernels/fractional_avg_pool_op.cc
@@ -0,0 +1,354 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <cmath>
+#include <random>
+#include <vector>
+
+#include "tensorflow/core/kernels/fractional_pool_common.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename T>
+class FractionalAvgPoolOp : public OpKernel {
+ public:
+  explicit FractionalAvgPoolOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
+    OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
+    OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
+    OP_REQUIRES(context, pooling_ratio_.size() == 4,
+                errors::InvalidArgument(
+                    "pooling_ratio field must specify 4 dimensions"));
+    OP_REQUIRES(
+        context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
+        errors::Unimplemented("Fractional average pooling is not yet "
+                              "supported on the batch nor channel dimension."));
+    OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
+    pooling_region_generated_ = false;
+    // Initialize philox random generator.
+    OP_REQUIRES_OK(context, generator_.Init(context));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        ConstEigenMatrixMap;
+    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        EigenMatrixMap;
+
+    constexpr int tensor_in_and_out_dims = 4;
+
+    const Tensor& tensor_in = context->input(0);
+    OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
+                errors::InvalidArgument("tensor_in must be 4-dimensional"));
+
+    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
+      input_size_.push_back(tensor_in.dim_size(i));
+    }
+    // Output size.
+    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
+      output_size_.push_back(
+          static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
+      DCHECK_GT(output_size_[i], 0);
+    }
+
+    // Generate pooling sequence.
+    std::vector<int64> row_cum_seq;
+    std::vector<int64> col_cum_seq;
+    if (deterministic_) {
+      if (pooling_region_generated_) {
+        row_cum_seq = row_cum_seq_;
+        col_cum_seq = col_cum_seq_;
+      } else {
+        row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
+                                              &generator_, pseudo_random_);
+        col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
+                                              &generator_, pseudo_random_);
+        mutex_lock lock(mu_);
+        row_cum_seq_ = row_cum_seq;
+        col_cum_seq_ = col_cum_seq;
+        pooling_region_generated_ = true;
+      }
+    } else {
+      row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
+                                            &generator_, pseudo_random_);
+      col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
+                                            &generator_, pseudo_random_);
+    }
+
+    // Prepare output.
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({output_size_[0], output_size_[1],
+                                       output_size_[2], output_size_[3]}),
+                       &output_tensor));
+    Tensor* output_row_seq_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       1, TensorShape({static_cast<int64>(row_cum_seq.size())}),
+                       &output_row_seq_tensor));
+    Tensor* output_col_seq_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       2, TensorShape({static_cast<int64>(col_cum_seq.size())}),
+                       &output_col_seq_tensor));
+
+    ConstEigenMatrixMap in_mat(
+        tensor_in.flat<T>().data(), input_size_[3],
+        input_size_[2] * input_size_[1] * input_size_[0]);
+
+    EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
+                           output_size_[2] * output_size_[1] * output_size_[0]);
+    // out_count corresponds to number of elements in each pooling cell.
+    Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols());
+
+    // Initializes the output tensor and out_count with 0.
+    out_mat.setZero();
+    out_count.setZero();
+
+    auto output_row_seq_flat = output_row_seq_tensor->flat<int64>();
+    auto output_col_seq_flat = output_col_seq_tensor->flat<int64>();
+
+    // Set output tensors.
+    for (int i = 0; i < row_cum_seq.size(); ++i) {
+      output_row_seq_flat(i) = row_cum_seq[i];
+    }
+
+    for (int i = 0; i < col_cum_seq.size(); ++i) {
+      output_col_seq_flat(i) = col_cum_seq[i];
+    }
+
+    // For both input and output,
+    // 0: batch
+    // 1: row / row
+    // 2: col / col
+    // 3: depth / channel
+    const int64 row_max = input_size_[1] - 1;
+    const int64 col_max = input_size_[2] - 1;
+    for (int64 b = 0; b < input_size_[0]; ++b) {
+      // row sequence.
+      for (int64 hs = 0; hs < row_cum_seq.size() - 1; ++hs) {
+        // row start and end.
+        const int64 row_start = row_cum_seq[hs];
+        int64 row_end =
+            overlapping_ ? row_cum_seq[hs + 1] : row_cum_seq[hs + 1] - 1;
+        row_end = std::min(row_end, row_max);
+
+        // col sequence.
+        for (int64 ws = 0; ws < col_cum_seq.size() - 1; ++ws) {
+          const int64 out_offset =
+              (b * output_size_[1] + hs) * output_size_[2] + ws;
+          // col start and end.
+          const int64 col_start = col_cum_seq[ws];
+          int64 col_end =
+              overlapping_ ? col_cum_seq[ws + 1] : col_cum_seq[ws + 1] - 1;
+          col_end = std::min(col_end, col_max);
+          for (int64 h = row_start; h <= row_end; ++h) {
+            for (int64 w = col_start; w <= col_end; ++w) {
+              const int64 in_offset =
+                  (b * input_size_[1] + h) * input_size_[2] + w;
+              out_mat.col(out_offset) += in_mat.col(in_offset);
+              out_count(out_offset)++;
+            }
+          }
+        }
+      }
+    }
+    DCHECK_GT(out_count.minCoeff(), 0);
+    out_mat.array().rowwise() /= out_count.transpose().array();
+  }
+
+ private:
+  bool deterministic_;
+  // meaningful only when deterministic_ is true.
+  mutex mu_;
+  std::vector<int64> row_cum_seq_;
+  std::vector<int64> col_cum_seq_;
+  bool pooling_region_generated_;
+
+  std::vector<int32> input_size_;
+  std::vector<int32> output_size_;
+  std::vector<float> pooling_ratio_;
+  bool pseudo_random_;
+  bool overlapping_;
+  GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER_FRACTIONALAVGPOOL(type)                                      \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("FractionalAvgPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      FractionalAvgPoolOp<type>)
+
+REGISTER_FRACTIONALAVGPOOL(int32);
+REGISTER_FRACTIONALAVGPOOL(int64);
+REGISTER_FRACTIONALAVGPOOL(float);
+REGISTER_FRACTIONALAVGPOOL(double);
+
+#undef REGISTER_FRACTIONALAVGPOOL
+
+template <class T>
+class FractionalAvgPoolGradOp : public OpKernel {
+ public:
+  explicit FractionalAvgPoolGradOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Here's the basic idea:
+    // Batch and depth dimension are independent from row and col dimension. And
+    // because FractionalAvgPool currently only support pooling along row and
+    // col, we can basically think of this 4D tensor backpropagation as
+    // operation of a series of 2D planes.
+    //
+    // For each element of a 'slice' (2D plane) of output_backprop, we need to
+    // figure out its contributors when doing FractionalAvgPool operation. This
+    // can be done based on row_pooling_sequence, col_pooling_seq and
+    // overlapping.
+    // Once we figure out the original contributors, we just need to evenly
+    // divide the value of this element among these contributors.
+    //
+    // Internally, we divide the out_backprop tensor and store it in a temparary
+    // tensor of double type. And cast it to the corresponding type.
+    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        ConstEigenMatrixMap;
+    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        EigenMatrixMap;
+    typedef Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>>
+        EigenDoubleMatrixMap;
+
+    // Grab the inputs.
+    const Tensor& orig_input_tensor_shape = context->input(0);
+    OP_REQUIRES(context, orig_input_tensor_shape.dims() == 1 &&
+                             orig_input_tensor_shape.NumElements() == 4,
+                errors::InvalidArgument("original input tensor shape must be"
+                                        "1-dimensional and 4 elements"));
+    const Tensor& out_backprop = context->input(1);
+    const Tensor& row_seq_tensor = context->input(2);
+    const Tensor& col_seq_tensor = context->input(3);
+
+    const int64 out_batch = out_backprop.dim_size(0);
+    const int64 out_rows = out_backprop.dim_size(1);
+    const int64 out_cols = out_backprop.dim_size(2);
+    const int64 out_depth = out_backprop.dim_size(3);
+
+    auto row_seq_tensor_flat = row_seq_tensor.flat<int64>();
+    auto col_seq_tensor_flat = col_seq_tensor.flat<int64>();
+    auto orig_input_tensor_shape_flat = orig_input_tensor_shape.flat<int64>();
+
+    const int64 in_batch = orig_input_tensor_shape_flat(0);
+    const int64 in_rows = orig_input_tensor_shape_flat(1);
+    const int64 in_cols = orig_input_tensor_shape_flat(2);
+    const int64 in_depth = orig_input_tensor_shape_flat(3);
+
+    constexpr int tensor_in_and_out_dims = 4;
+    // Transform orig_input_tensor_shape into TensorShape
+    TensorShape in_shape;
+    for (auto i = 0; i < tensor_in_and_out_dims; ++i) {
+      in_shape.AddDim(orig_input_tensor_shape_flat(i));
+    }
+
+    // Create intermediate in_backprop.
+    Tensor in_backprop_tensor_temp;
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(DataTypeToEnum<double>::v(), in_shape,
+                                          &in_backprop_tensor_temp));
+    in_backprop_tensor_temp.flat<double>().setZero();
+    // Transform 4D tensor to 2D matrix.
+    EigenDoubleMatrixMap in_backprop_tensor_temp_mat(
+        in_backprop_tensor_temp.flat<double>().data(), in_depth,
+        in_cols * in_rows * in_batch);
+    ConstEigenMatrixMap out_backprop_mat(out_backprop.flat<T>().data(),
+                                         out_depth,
+                                         out_cols * out_rows * out_batch);
+    // Loop through each element of out_backprop and evenly distribute the
+    // element to the corresponding pooling cell.
+    const int64 in_max_row_index = in_rows - 1;
+    const int64 in_max_col_index = in_cols - 1;
+    for (int64 b = 0; b < out_batch; ++b) {
+      for (int64 r = 0; r < out_rows; ++r) {
+        const int64 in_row_start = row_seq_tensor_flat(r);
+        int64 in_row_end = overlapping_ ? row_seq_tensor_flat(r + 1)
+                                        : row_seq_tensor_flat(r + 1) - 1;
+        in_row_end = std::min(in_row_end, in_max_row_index);
+        for (int64 c = 0; c < out_cols; ++c) {
+          const int64 in_col_start = col_seq_tensor_flat(c);
+          int64 in_col_end = overlapping_ ? col_seq_tensor_flat(c + 1)
+                                          : col_seq_tensor_flat(c + 1) - 1;
+          in_col_end = std::min(in_col_end, in_max_col_index);
+
+          const int64 num_elements_in_pooling_cell =
+              (in_row_end - in_row_start + 1) * (in_col_end - in_col_start + 1);
+          const int64 out_index = (b * out_rows + r) * out_cols + c;
+          // Now we can evenly distribute out_backprop(b, h, w, *) to
+          // in_backprop(b, hs:he, ws:we, *).
+          for (int64 in_r = in_row_start; in_r <= in_row_end; ++in_r) {
+            for (int64 in_c = in_col_start; in_c <= in_col_end; ++in_c) {
+              const int64 in_index = (b * in_rows + in_r) * in_cols + in_c;
+              // Walk through each channel (depth).
+              for (int64 d = 0; d < out_depth; ++d) {
+                const double out_backprop_element = static_cast<double>(
+                    out_backprop_mat.coeffRef(d, out_index));
+                double& in_backprop_ref =
+                    in_backprop_tensor_temp_mat.coeffRef(d, in_index);
+                in_backprop_ref +=
+                    out_backprop_element / num_elements_in_pooling_cell;
+              }
+            }
+          }
+        }
+      }
+    }
+
+    // Depending on the type, cast double to type T.
+    Tensor* in_backprop_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, in_shape, &in_backprop_tensor));
+    auto in_backprop_tensor_flat = in_backprop_tensor->flat<T>();
+    auto in_backprop_tensor_temp_flat = in_backprop_tensor_temp.flat<double>();
+    for (int64 i = 0; i < in_backprop_tensor_flat.size(); ++i) {
+      in_backprop_tensor_flat(i) =
+          static_cast<T>(in_backprop_tensor_temp_flat(i));
+    }
+  }
+
+ private:
+  bool overlapping_;
+};
+
+#define REGISTER_FRACTIONALAVGPOOLGRAD(type)              \
+  REGISTER_KERNEL_BUILDER(Name("FractionalAvgPoolGrad")   \
+                              .Device(DEVICE_CPU)         \
+                              .TypeConstraint<type>("T"), \
+                          FractionalAvgPoolGradOp<type>)
+
+REGISTER_FRACTIONALAVGPOOLGRAD(int32);
+REGISTER_FRACTIONALAVGPOOLGRAD(int64);
+REGISTER_FRACTIONALAVGPOOLGRAD(float);
+REGISTER_FRACTIONALAVGPOOLGRAD(double);
+
+#undef REGISTER_FRACTIONALAVGPOOLGRAD
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc
new file mode 100644
index 00000000000..482491b504a
--- /dev/null
+++ b/tensorflow/core/kernels/fractional_max_pool_op.cc
@@ -0,0 +1,381 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <cmath>
+#include <random>
+#include <vector>
+
+#include "tensorflow/core/kernels/fractional_pool_common.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename T>
+class FractionalMaxPoolOp : public OpKernel {
+ public:
+  explicit FractionalMaxPoolOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
+    OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
+    OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
+
+    OP_REQUIRES(context, pooling_ratio_.size() == 4,
+                errors::InvalidArgument("pooling_ratio field must "
+                                        "specify 4 dimensions"));
+
+    OP_REQUIRES(
+        context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
+        errors::Unimplemented("Fractional max pooling is not yet "
+                              "supported on the batch nor channel dimension."));
+
+    OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
+    pooling_region_generated_ = false;
+    // Initialize philox random generator.
+    OP_REQUIRES_OK(context, generator_.Init(context));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        ConstEigenMatrixMap;
+    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        EigenMatrixMap;
+
+    constexpr int tensor_in_and_out_dims = 4;
+
+    const Tensor& tensor_in = context->input(0);
+    OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
+                errors::InvalidArgument("tensor_in must be 4-dimensional"));
+
+    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
+      input_size_.push_back(tensor_in.dim_size(i));
+    }
+    // Output size.
+    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
+      output_size_.push_back(
+          static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
+      DCHECK_GT(output_size_[i], 0);
+    }
+
+    // Generate pooling sequence.
+    std::vector<int64> height_cum_seq;
+    std::vector<int64> width_cum_seq;
+    if (deterministic_) {
+      if (pooling_region_generated_) {
+        height_cum_seq = height_cum_seq_;
+        width_cum_seq = width_cum_seq_;
+      } else {
+        height_cum_seq = GeneratePoolingSequence(
+            input_size_[1], output_size_[1], &generator_, pseudo_random_);
+        width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
+                                                &generator_, pseudo_random_);
+        mutex_lock lock(mu_);
+        height_cum_seq_ = height_cum_seq;
+        width_cum_seq_ = width_cum_seq;
+        pooling_region_generated_ = true;
+      }
+    } else {
+      height_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
+                                               &generator_, pseudo_random_);
+      width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
+                                              &generator_, pseudo_random_);
+    }
+
+    // Prepare output.
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({output_size_[0], output_size_[1],
+                                       output_size_[2], output_size_[3]}),
+                       &output_tensor));
+    Tensor* output_height_seq_tensor = nullptr;
+    OP_REQUIRES_OK(
+        context,
+        context->allocate_output(
+            1, TensorShape({static_cast<int64>(height_cum_seq.size())}),
+            &output_height_seq_tensor));
+    Tensor* output_width_seq_tensor = nullptr;
+    OP_REQUIRES_OK(
+        context, context->allocate_output(
+                     2, TensorShape({static_cast<int64>(width_cum_seq.size())}),
+                     &output_width_seq_tensor));
+
+    ConstEigenMatrixMap in_mat(
+        tensor_in.flat<T>().data(), input_size_[3],
+        input_size_[2] * input_size_[1] * input_size_[0]);
+
+    EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
+                           output_size_[2] * output_size_[1] * output_size_[0]);
+
+    // Initializes the output tensor with MIN<T>.
+    output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
+
+    auto output_height_seq_flat = output_height_seq_tensor->flat<int64>();
+    auto output_width_seq_flat = output_width_seq_tensor->flat<int64>();
+
+    // Set output tensors.
+    for (int i = 0; i < height_cum_seq.size(); ++i) {
+      output_height_seq_flat(i) = height_cum_seq[i];
+    }
+
+    for (int i = 0; i < width_cum_seq.size(); ++i) {
+      output_width_seq_flat(i) = width_cum_seq[i];
+    }
+
+    // For both input and output,
+    // 0: batch
+    // 1: height / row
+    // 2: width / col
+    // 3: depth / channel
+    const int64 height_max = input_size_[1] - 1;
+    const int64 width_max = input_size_[2] - 1;
+    for (int64 b = 0; b < input_size_[0]; ++b) {
+      // height sequence.
+      for (int64 hs = 0; hs < height_cum_seq.size() - 1; ++hs) {
+        // height start and end.
+        const int64 height_start = height_cum_seq[hs];
+        int64 height_end =
+            overlapping_ ? height_cum_seq[hs + 1] : height_cum_seq[hs + 1] - 1;
+        height_end = std::min(height_end, height_max);
+
+        // width sequence.
+        for (int64 ws = 0; ws < width_cum_seq.size() - 1; ++ws) {
+          const int64 out_offset =
+              (b * output_size_[1] + hs) * output_size_[2] + ws;
+          // width start and end.
+          const int64 width_start = width_cum_seq[ws];
+          int64 width_end =
+              overlapping_ ? width_cum_seq[ws + 1] : width_cum_seq[ws + 1] - 1;
+          width_end = std::min(width_end, width_max);
+          for (int64 h = height_start; h <= height_end; ++h) {
+            for (int64 w = width_start; w <= width_end; ++w) {
+              const int64 in_offset =
+                  (b * input_size_[1] + h) * input_size_[2] + w;
+              out_mat.col(out_offset) =
+                  out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
+            }
+          }
+        }
+      }
+    }
+  }
+
+ private:
+  bool deterministic_;
+  // meaningful only when deterministic_ is true.
+  mutex mu_;
+  std::vector<int64> height_cum_seq_;
+  std::vector<int64> width_cum_seq_;
+  bool pooling_region_generated_;
+
+  std::vector<int32> input_size_;
+  std::vector<int32> output_size_;
+  std::vector<float> pooling_ratio_;
+  bool pseudo_random_;
+  bool overlapping_;
+  GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER_FRACTIONALMAXPOOL(type)                                      \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("FractionalMaxPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      FractionalMaxPoolOp<type>)
+
+REGISTER_FRACTIONALMAXPOOL(int32);
+REGISTER_FRACTIONALMAXPOOL(int64);
+REGISTER_FRACTIONALMAXPOOL(float);
+REGISTER_FRACTIONALMAXPOOL(double);
+
+#undef REGISTER_FRACTIONALMAXPOOL
+
+static const int kInvalidMaxPoolingIndex = -1;
+
+template <class T>
+class FractionalMaxPoolGradOp : public OpKernel {
+ public:
+  explicit FractionalMaxPoolGradOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // There are two steps when calculating gradient for FractionalMaxPool.
+    // 1) Walk through the process of calculating fractional pooling given
+    //    pooling region; however, in the process, keep track of where the max
+    //    element comes from. (arg_max)
+    // 2) Populate the value of out_backprop to where arg_max indicates. If
+    //    we support overlapping, it is likely to have multiple out_backprop[i]
+    //    propagates back to the same arg_max value.
+    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        ConstEigenMatrixMap;
+    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        EigenMatrixMap;
+    typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
+        EigenIndexMatrixMap;
+
+    const Tensor& tensor_in = context->input(0);
+    const Tensor& tensor_out = context->input(1);
+    const Tensor& out_backprop = context->input(2);
+    const Tensor& height_seq_tensor = context->input(3);
+    const Tensor& width_seq_tensor = context->input(4);
+
+    // Just to make it similar to FractionalMaxPoolOp.
+    constexpr int tensor_in_and_out_dims = 4;
+    std::vector<int64> input_size;
+    std::vector<int64> output_size;
+    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
+      input_size.push_back(tensor_in.dim_size(i));
+    }
+    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
+      output_size.push_back(tensor_out.dim_size(i));
+    }
+
+    // ---------
+    // Step 1
+    // ---------
+    Tensor tensor_out_dup;
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(DataTypeToEnum<T>::v(),
+                                          tensor_out.shape(), &tensor_out_dup));
+    Tensor tensor_out_arg_max;
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(),
+                                                   tensor_out.shape(),
+                                                   &tensor_out_arg_max));
+    // Find arg_max for each tensor_out
+    ConstEigenMatrixMap tensor_in_mat(
+        tensor_in.flat<T>().data(), input_size[3],
+        input_size[2] * input_size[1] * input_size[0]);
+    EigenMatrixMap tensor_out_dup_mat(
+        tensor_out_dup.flat<T>().data(), output_size[3],
+        output_size[2] * output_size[1] * output_size[0]);
+    EigenIndexMatrixMap tensor_out_arg_max_mat(
+        tensor_out_arg_max.flat<int64>().data(), output_size[3],
+        output_size[2] * output_size[1] * output_size[0]);
+
+    tensor_out_arg_max.flat<int64>().setConstant(kInvalidMaxPoolingIndex);
+    // Initializes the duplicate output tensor with MIN<T>.
+    tensor_out_dup.flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
+
+    auto height_seq_tensor_flat = height_seq_tensor.flat<int64>();
+    auto width_seq_tensor_flat = width_seq_tensor.flat<int64>();
+
+    // Now walk through the process of fractional max pooling again.
+    // For both input and output,
+    // 0: batch
+    // 1: height / row
+    // 2: width / col
+    // 3: depth / channel
+    const int64 height_max = input_size[1] - 1;
+    const int64 width_max = input_size[2] - 1;
+    for (int64 b = 0; b < input_size[0]; ++b) {
+      // height sequence.
+      for (int64 hs = 0; hs < height_seq_tensor.dim_size(0) - 1; ++hs) {
+        // height start and end.
+        const int64 height_start = height_seq_tensor_flat(hs);
+        int64 height_end = overlapping_ ? height_seq_tensor_flat(hs + 1)
+                                        : height_seq_tensor_flat(hs + 1) - 1;
+        height_end = std::min(height_end, height_max);
+
+        // width sequence.
+        for (int64 ws = 0; ws < width_seq_tensor.dim_size(0) - 1; ++ws) {
+          const int64 out_index =
+              (b * output_size[1] + hs) * output_size[2] + ws;
+          // width start and end.
+          const int64 width_start = width_seq_tensor_flat(ws);
+          int64 width_end = overlapping_ ? width_seq_tensor_flat(ws + 1)
+                                         : width_seq_tensor_flat(ws + 1) - 1;
+          width_end = std::min(width_end, width_max);
+          for (int64 h = height_start; h <= height_end; ++h) {
+            for (int64 w = width_start; w <= width_end; ++w) {
+              const int64 in_index =
+                  (b * input_size[1] + h) * input_size[2] + w;
+              // Walk through each channel (depth).
+              for (int64 d = 0; d < input_size[3]; ++d) {
+                const T& input_ref = tensor_in_mat.coeffRef(d, in_index);
+                T& output_ref = tensor_out_dup_mat.coeffRef(d, out_index);
+                int64& out_arg_max_ref =
+                    tensor_out_arg_max_mat.coeffRef(d, out_index);
+                if (output_ref < input_ref ||
+                    out_arg_max_ref == kInvalidMaxPoolingIndex) {
+                  output_ref = input_ref;
+                  int input_offset = in_index * input_size[3] + d;
+                  out_arg_max_ref = input_offset;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+
+    // Check tensor_out_dup is the same as tensor_out.
+    ConstEigenMatrixMap tensor_out_mat(
+        tensor_out.flat<T>().data(), output_size[3],
+        output_size[2] * output_size[1] * output_size[0]);
+    const int64 num_reshaped_cols =
+        output_size[2] * output_size[1] * output_size[0];
+    for (int64 i = 0; i < num_reshaped_cols; ++i) {
+      for (int64 j = 0; j < output_size[3]; ++j) {
+        DCHECK_EQ(tensor_out_dup_mat(j, i), tensor_out_mat(j, i));
+      }
+    }
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, tensor_in.shape(), &output));
+    output->flat<T>().setZero();
+
+    auto out_backprop_flat = out_backprop.flat<T>();
+    auto input_backprop_flat = output->flat<T>();
+    auto out_arg_max_flat = tensor_out_arg_max.flat<int64>();
+    int num_total_outputs = out_backprop_flat.size();
+    int num_total_inputs = input_backprop_flat.size();
+
+    for (int index = 0; index < num_total_outputs; ++index) {
+      int input_backprop_index = out_arg_max_flat(index);
+      // According to maxpooling_op.cc, the performance impact below is small.
+      CHECK(input_backprop_index >= 0 &&
+            input_backprop_index < num_total_inputs)
+          << "Invalid input backprop index: " << input_backprop_index << ", "
+          << num_total_inputs;
+      input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
+    }
+  }
+
+ private:
+  bool overlapping_;
+};
+
+#define REGISTER_FRACTIONALMAXPOOLGRAD(type)              \
+  REGISTER_KERNEL_BUILDER(Name("FractionalMaxPoolGrad")   \
+                              .Device(DEVICE_CPU)         \
+                              .TypeConstraint<type>("T"), \
+                          FractionalMaxPoolGradOp<type>)
+
+REGISTER_FRACTIONALMAXPOOLGRAD(int32);
+REGISTER_FRACTIONALMAXPOOLGRAD(int64);
+REGISTER_FRACTIONALMAXPOOLGRAD(float);
+REGISTER_FRACTIONALMAXPOOLGRAD(double);
+
+#undef REGISTER_FRACTIONALMAXPOOLGRAD
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/fractional_pool_common.cc b/tensorflow/core/kernels/fractional_pool_common.cc
new file mode 100644
index 00000000000..f3a3dfda8df
--- /dev/null
+++ b/tensorflow/core/kernels/fractional_pool_common.cc
@@ -0,0 +1,134 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/fractional_pool_common.h"
+
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+namespace tensorflow {
+static std::vector<int64> GeneratePoolingSequencePseudoRandom(
+    int input_length, int output_length, GuardedPhiloxRandom* generator) {
+  std::vector<int64> cum_seq(output_length + 1, 0);
+  std::vector<int64> diff(output_length, 0);
+
+  double alpha = static_cast<double>(input_length) / output_length;
+  int k = input_length / output_length;
+
+  // In the paper [1], author proposes the following procedure to generate a
+  // pseudo random pooling region:
+  //   1) Set a_0 = 1, a_Nout = Nin;
+  //   2) a_i = ceil(alpha*(u+i))
+  //      in which, i = 1, 2, ... , Nout-1
+  //                u is a random number in (0,1) for all i
+  //                alpha = Nin/Nout in (1,2)
+  // The sequence {a_i} should satisfy a_i-a_{i-1} = 1 or 2
+  // Note: for step 1), it makes more sense to make a_Nout = Nin+1, that way,
+  //    a_i-a_{i-1} = 1 or 2 is also true for i = Nout.
+  //
+  // However, there are choices of alpha and u that will make
+  // a_i - a_{i-1} > 2. This happens at the left boundary. For example, with
+  // alpha = 1.732, u = 0.8, then a_1 = 4, a_1-a_0 = 3.
+  // This is why u_max1 is needed, i.e. u is a random number in (0,u_max1)
+  // instead of (0,1).
+  // Define k = ceil(alpha)-1, then we require:
+  //   a_1 = alpha*(u+1) <= a_0+(k+1)
+  // ===> This gives u_max1 = (k+2)/alpha - 1.
+  //
+  // In addition, when extending the pooling sequence generation process for
+  // alpha beyond (1,2), e.g. (k,k+1); a check on the right boundary is also
+  // needed to make sure the last gap a_Nout-a_{Nout-1} >= k. Solving it gives
+  // u_max2 = (Nin+1-k)/alpha - (Nout-1)
+  // Here is an example where u > u_max2, alpha = 2.3, u = 0.7, u_max2 = 0.565,
+  // Nin = 23, Nout = 10; the sequence
+  // from a_0 to a_10 is:
+  // [1, 4, 7, 9, 11, 14, 16, 18, 21, 23, 24]
+  // The last gap is only 1.
+  //
+  // [1]: https://arxiv.org/abs/1412.6071
+  double u_max1 = (k + 2) / alpha - 1;
+  double u_max2 = (input_length + 1 - k) / alpha - (output_length - 1);
+  double max_u = std::min(u_max1, u_max2);
+
+  // Generate random number in parallel.
+  auto local_gen = generator->ReserveSamples32(2);
+  random::SimplePhilox random(&local_gen);
+  const double u = random.RandDouble() * max_u;
+
+  cum_seq[0] = 1;
+  cum_seq[output_length] = input_length + 1;
+  for (int i = 1; i < output_length; ++i) {
+    cum_seq[i] = static_cast<int>(ceil(alpha * (i + u)));
+  }
+
+  for (int i = 0; i < output_length; ++i) {
+    diff[i] = cum_seq[i + 1] - cum_seq[i];
+  }
+
+  return diff;
+}
+
+static std::vector<int64> GeneratePoolingSequenceRandom(
+    int input_length, int output_length, GuardedPhiloxRandom* generator) {
+  int k = input_length / output_length;
+  int num_random_spot = input_length % output_length;
+  std::vector<int64> diff(output_length, k);
+
+  for (int i = 0; i < num_random_spot; ++i) {
+    diff[i] += 1;
+  }
+
+  // Randomly shuffle this vector.
+  auto local_gen = generator->ReserveSamples32(diff.size());
+  random::SingleSampleAdapter<random::PhiloxRandom> single(&local_gen);
+  const auto uniform = [&single](uint32 n) { return single() % n; };
+  RandomShuffle(diff.begin(), diff.end(), uniform);
+
+  return diff;
+}
+
+std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
+                                           GuardedPhiloxRandom* generator,
+                                           bool pseudo_random) {
+  std::vector<int64> diff;
+  // This is a case that regular pooling can handle, just return diff with
+  // each element input_length/output_length.
+  if (input_length % output_length == 0) {
+    diff = std::vector<int64>(output_length, input_length / output_length);
+  }
+
+  if (pseudo_random) {
+    diff = GeneratePoolingSequencePseudoRandom(input_length, output_length,
+                                               generator);
+  } else {
+    diff =
+        GeneratePoolingSequenceRandom(input_length, output_length, generator);
+  }
+
+  // Sanity check.
+  int k = input_length / output_length;
+  for (int i = 0; i < output_length; ++i) {
+    // k<= diff[i] <= k+1.
+    DCHECK_GE(diff[i], k);
+    DCHECK_LE(diff[i], k + 1);
+  }
+
+  // Return cumulative sequence.
+  std::vector<int64> cum_seq(output_length + 1, 0);
+  for (int i = 1; i < cum_seq.size(); ++i) {
+    cum_seq[i] = cum_seq[i - 1] + diff[i - 1];
+  }
+  return cum_seq;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/fractional_pool_common.h b/tensorflow/core/kernels/fractional_pool_common.h
new file mode 100644
index 00000000000..df0bbbfa066
--- /dev/null
+++ b/tensorflow/core/kernels/fractional_pool_common.h
@@ -0,0 +1,78 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
+#define TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+// Shuffle a container randomly, copied from random_shuffle_op.cc
+template <class Iter, class Random>
+static inline void RandomShuffle(Iter first, Iter last, const Random& uniform) {
+  if (first == last) {
+    return;
+  }
+  const auto stop = last - 1;
+  for (auto i = first; i != stop; ++i) {
+    using std::iter_swap;
+    iter_swap(i, i + uniform(last - i));
+  }
+}
+
+// Generate pooling sequence for fractional pooling along one dimension.
+//
+// Regular max/avg pooling can be viewed as a special case, in which given the
+//     * input_length: e.g. 10
+//     * output_length: e.g. 5
+// it will generate pooling sequence as
+//     diff sequence: [2, 2, 2, 2, 2]
+// or as
+//     cumulative sequence: [0, 2, 4, 6, 8, 10]
+//
+// In the case of fractional pooling, input_length is not an integer multiple of
+// output_length, randomness plays a role when generating pooling sequence.
+// There are two type of randomness (random vs pseudo-random) defined in paper:
+// http://arxiv.org/abs/1412.6071
+// You can check the paper for the difference between these two types.
+//
+// In summary, the generated diff sequence satisfy the following properties for
+// both types of randomness:
+//     * length(generated_diff_pooling_sequence) = output_length
+//     * sum(generated_diff_pooling_sequence) = input_length
+//     * Let's define floor(input_length / output_length) = K, then
+//       K <= generated_diff_pooling_sequence[i] <= K+1
+// For example, when input_length = 10, output_length = 6, the followings are
+// valid pooling sequence:
+//     * [1, 2, 2, 1, 2, 2]
+//     * [1, 1, 2, 2, 2, 2]
+// [1, 3, 2, 2, 2, 2] is not valid.
+//
+// Args:
+//   input_length:  See above explanation
+//   output_length:  See above explanation
+//   generator:  Parallel version of random number generator
+//   pseudo_random:  Whether or not use pseudo-random
+// Returns:
+//   pooling_sequence:  This is the cumulative pooling sequence.
+std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
+                                           GuardedPhiloxRandom* generator,
+                                           bool pseudo_random);
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index affbd269669..fe78541d08f 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1515,4 +1515,205 @@ values: The `k` largest elements along each last dimensional slice.
 indices: The indices of `values` within the last dimension of `input`.
 )doc");
 
+// --------------------------------------------------------------------------
+
+REGISTER_OP("FractionalMaxPool")
+    .Input("value: T")
+    .Output("output: T")
+    .Output("row_pooling_sequence: int64")
+    .Output("col_pooling_sequence: int64")
+    .Attr("pooling_ratio: list(float) >=4")
+    .Attr("pseudo_random: bool = false")
+    .Attr("overlapping: bool = false")
+    .Attr("deterministic: bool = false")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .Attr("T: {float, double, int32, int64}")
+    .Doc(R"doc(
+Performs fractional max pooling on the input.
+
+Fractional max pooling is slightly different than regular max pooling.  In
+regular max pooling, you downsize an input set by taking the maximum value of
+smaller N x N subsections of the set (often 2x2), and try to reduce the set by
+a factor of N, where N is an integer.  Fractional max pooling, as you might
+expect from the word "fractional", means that the overall reduction ratio N
+does not have to be an integer.
+
+The sizes of the pooling regions are generated randomly but are fairly uniform.
+For example, let's look at the height dimension, and the constraints on the
+list of rows that will be pool boundaries.
+
+First we define the following:
+
+1.  input_row_length : the number of rows from the input set
+2.  output_row_length : which will be smaller than the input
+3.  alpha = input_row_length / output_row_length : our reduction ratio
+4.  K = floor(alpha)
+5.  row_pooling_sequence : this is the result list of pool boundary rows
+
+Then, row_pooling_sequence should satisfy:
+
+1.  a[0] = 0 : the first value of the sequence is 0
+2.  a[end] = input_row_length : the last value of the sequence is the size
+3.  K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
+4.  length(row_pooling_sequence) = output_row_length+1
+
+For more details on fractional max pooling, see this paper:
+[Benjamin Graham, Fractional Max-Pooling]
+(http://arxiv.org/abs/1412.6071)
+
+value: 4-D with shape `[batch, height, width, channels]`.
+pooling_ratio: Pooling ratio for each dimension of `value`, currently only
+  supports row and col dimension and should be >= 1.0. For example, a valid
+  pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
+  must be 1.0 because we don't allow pooling on batch and channels
+  dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
+  respectively.
+pseudo_random: When set to True, generates the pooling sequence in a
+  pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
+  Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
+  difference between pseudorandom and random.
+overlapping: When set to True, it means when pooling, the values at the boundary
+  of adjacent pooling cells are used by both cells. For example:
+
+  `index  0  1  2  3  4`
+
+  `value  20 5  16 3  7`
+
+  If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
+  The result would be [20, 16] for fractional max pooling.
+deterministic: When set to True, a fixed pooling region will be used when
+  iterating over a FractionalMaxPool node in the computation graph. Mainly used
+  in unit test to make FractionalMaxPool deterministic.
+seed: If either seed or seed2 are set to be non-zero, the random number
+  generator is seeded by the given seed.  Otherwise, it is seeded by a
+  random seed.
+seed2: An second seed to avoid seed collision.
+output: output tensor after fractional max pooling.
+row_pooling_sequence: row pooling sequence, needed to calculate gradient.
+col_pooling_sequence: column pooling sequence, needed to calculate gradient.
+)doc");
+
+REGISTER_OP("FractionalMaxPoolGrad")
+    .Input("orig_input: T")
+    .Input("orig_output: T")
+    .Input("out_backprop: T")
+    .Input("row_pooling_sequence: int64")
+    .Input("col_pooling_sequence: int64")
+    .Output("output: T")
+    .Attr("overlapping: bool = false")
+    .Attr("T: {float, double, int32, int64}")
+    .Doc(R"doc(
+Computes gradient of the FractionalMaxPool function.
+
+orig_input: Original input for `fractional_max_pool`
+orig_output: Original output for `fractional_max_pool`
+out_backprop: 4-D with shape `[batch, height, width, channels]`.  Gradients
+  w.r.t. the output of `fractional_max_pool`.
+row_pooling_sequence: row pooling sequence, form pooling region with
+  col_pooling_sequence.
+col_pooling_sequence: column pooling sequence, form pooling region with
+  row_pooling sequence.
+overlapping: When set to True, it means when pooling, the values at the boundary
+  of adjacent pooling cells are used by both cells. For example:
+
+  `index  0  1  2  3  4`
+
+  `value  20 5  16 3  7`
+
+  If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
+  The result would be [20, 16] for fractional max pooling.
+output: 4-D.  Gradients w.r.t. the input of `fractional_max_pool`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("FractionalAvgPool")
+    .Input("value: T")
+    .Output("output: T")
+    .Output("row_pooling_sequence: int64")
+    .Output("col_pooling_sequence: int64")
+    .Attr("pooling_ratio: list(float) >=4")
+    .Attr("pseudo_random: bool = false")
+    .Attr("overlapping: bool = false")
+    .Attr("deterministic: bool = false")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .Attr("T: {float, double, int32, int64}")
+    .Doc(R"doc(
+Performs fractional average pooling on the input.
+
+Fractional average pooling is similar to Fractional max pooling in the pooling
+region generation step. The only difference is that after pooling regions are
+generated, a mean operation is performed instead of a max operation in each
+pooling region.
+
+value: 4-D with shape `[batch, height, width, channels]`.
+pooling_ratio: Pooling ratio for each dimension of `value`, currently only
+  supports row and col dimension and should be >= 1.0. For example, a valid
+  pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
+  must be 1.0 because we don't allow pooling on batch and channels
+  dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
+  respectively.
+pseudo_random: When set to True, generates the pooling sequence in a
+  pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
+  Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
+  difference between pseudorandom and random.
+overlapping: When set to True, it means when pooling, the values at the boundary
+  of adjacent pooling cells are used by both cells. For example:
+
+  `index  0  1  2  3  4`
+
+  `value  20 5  16 3  7`
+
+  If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
+  The result would be [41/3, 26/3] for fractional avg pooling.
+deterministic: When set to True, a fixed pooling region will be used when
+  iterating over a FractionalAvgPool node in the computation graph. Mainly used
+  in unit test to make FractionalAvgPool deterministic.
+seed: If either seed or seed2 are set to be non-zero, the random number
+  generator is seeded by the given seed.  Otherwise, it is seeded by a
+  random seed.
+seed2: An second seed to avoid seed collision.
+output: output tensor after fractional avg pooling.
+row_pooling_sequence: row pooling sequence, needed to calculate gradient.
+col_pooling_sequence: column pooling sequence, needed to calculate gradient.
+)doc");
+
+REGISTER_OP("FractionalAvgPoolGrad")
+    .Input("orig_input_tensor_shape: int64")
+    .Input("out_backprop: T")
+    .Input("row_pooling_sequence: int64")
+    .Input("col_pooling_sequence: int64")
+    .Output("output: T")
+    .Attr("overlapping: bool = false")
+    .Attr("T: {float, double, int32, int64}")
+    .Doc(R"doc(
+Computes gradient of the FractionalAvgPool function.
+
+Unlike FractionalMaxPoolGrad, we don't need to find arg_max for
+FractionalAvgPoolGrad, we just need to evenly back-propagate each element of
+out_backprop to those indices that form the same pooling cell. Therefore, we
+just need to know the shape of original input tensor, instead of the whole
+tensor.
+
+orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool`
+out_backprop: 4-D with shape `[batch, height, width, channels]`.  Gradients
+  w.r.t. the output of `fractional_avg_pool`.
+row_pooling_sequence: row pooling sequence, form pooling region with
+  col_pooling_sequence.
+col_pooling_sequence: column pooling sequence, form pooling region with
+  row_pooling sequence.
+overlapping: When set to True, it means when pooling, the values at the boundary
+  of adjacent pooling cells are used by both cells. For example:
+
+  `index  0  1  2  3  4`
+
+  `value  20 5  16 3  7`
+
+  If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
+  The result would be [41/3, 26/3] for fractional avg pooling.
+output: 4-D.  Gradients w.r.t. the input of `fractional_avg_pool`.
+)doc");
+
 }  // namespace tensorflow
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fractional_max_pool.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fractional_max_pool.md
new file mode 100644
index 00000000000..ef10897212f
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fractional_max_pool.md
@@ -0,0 +1,82 @@
+### `tf.nn.fractional_max_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_max_pool}
+
+Performs fractional max pooling on the input.
+
+Fractional max pooling is slightly different than regular max pooling.  In
+regular max pooling, you downsize an input set by taking the maximum value of
+smaller N x N subsections of the set (often 2x2), and try to reduce the set by
+a factor of N, where N is an integer.  Fractional max pooling, as you might
+expect from the word "fractional", means that the overall reduction ratio N
+does not have to be an integer.
+
+The sizes of the pooling regions are generated randomly but are fairly uniform.
+For example, let's look at the height dimension, and the constraints on the
+list of rows that will be pool boundaries.
+
+First we define the following:
+
+1.  input_row_length : the number of rows from the input set
+2.  output_row_length : which will be smaller than the input
+3.  alpha = input_row_length / output_row_length : our reduction ratio
+4.  K = floor(alpha)
+5.  row_pooling_sequence : this is the result list of pool boundary rows
+
+Then, row_pooling_sequence should satisfy:
+
+1.  a[0] = 0 : the first value of the sequence is 0
+2.  a[end] = input_row_length : the last value of the sequence is the size
+3.  K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
+4.  length(row_pooling_sequence) = output_row_length+1
+
+For more details on fractional max pooling, see this paper:
+[Benjamin Graham, Fractional Max-Pooling]
+(http://arxiv.org/abs/1412.6071)
+
+##### Args:
+
+
+*  <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+    4-D with shape `[batch, height, width, channels]`.
+*  <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
+    Pooling ratio for each dimension of `value`, currently only
+    supports row and col dimension and should be >= 1.0. For example, a valid
+    pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
+    must be 1.0 because we don't allow pooling on batch and channels
+    dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
+    respectively.
+*  <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, generates the pooling sequence in a
+    pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
+    Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
+    difference between pseudorandom and random.
+*  <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, it means when pooling, the values at the boundary
+    of adjacent pooling cells are used by both cells. For example:
+
+    `index  0  1  2  3  4`
+
+    `value  20 5  16 3  7`
+
+    If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
+    The result would be [20, 16] for fractional max pooling.
+
+*  <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, a fixed pooling region will be used when
+    iterating over a FractionalMaxPool node in the computation graph. Mainly used
+    in unit test to make FractionalMaxPool deterministic.
+*  <b>`seed`</b>: An optional `int`. Defaults to `0`.
+    If either seed or seed2 are set to be non-zero, the random number
+    generator is seeded by the given seed.  Otherwise, it is seeded by a
+    random seed.
+*  <b>`seed2`</b>: An optional `int`. Defaults to `0`.
+    An second seed to avoid seed collision.
+*  <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
+
+*  <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional max pooling.
+*  <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
+*  <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.fractional_avg_pool.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.fractional_avg_pool.md
new file mode 100644
index 00000000000..595e6649739
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.fractional_avg_pool.md
@@ -0,0 +1,57 @@
+### `tf.nn.fractional_avg_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_avg_pool}
+
+Performs fractional average pooling on the input.
+
+Fractional average pooling is similar to Fractional max pooling in the pooling
+region generation step. The only difference is that after pooling regions are
+generated, a mean operation is performed instead of a max operation in each
+pooling region.
+
+##### Args:
+
+
+*  <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+    4-D with shape `[batch, height, width, channels]`.
+*  <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
+    Pooling ratio for each dimension of `value`, currently only
+    supports row and col dimension and should be >= 1.0. For example, a valid
+    pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
+    must be 1.0 because we don't allow pooling on batch and channels
+    dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
+    respectively.
+*  <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, generates the pooling sequence in a
+    pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
+    Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
+    difference between pseudorandom and random.
+*  <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, it means when pooling, the values at the boundary
+    of adjacent pooling cells are used by both cells. For example:
+
+    `index  0  1  2  3  4`
+
+    `value  20 5  16 3  7`
+
+    If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
+    The result would be [41/3, 26/3] for fractional avg pooling.
+
+*  <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, a fixed pooling region will be used when
+    iterating over a FractionalAvgPool node in the computation graph. Mainly used
+    in unit test to make FractionalAvgPool deterministic.
+*  <b>`seed`</b>: An optional `int`. Defaults to `0`.
+    If either seed or seed2 are set to be non-zero, the random number
+    generator is seeded by the given seed.  Otherwise, it is seeded by a
+    random seed.
+*  <b>`seed2`</b>: An optional `int`. Defaults to `0`.
+    An second seed to avoid seed collision.
+*  <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
+
+*  <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional avg pooling.
+*  <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
+*  <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
+
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 4daa212c33e..9bb9236ca21 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -473,6 +473,8 @@
   * [`embedding_lookup_sparse`](../../api_docs/python/nn.md#embedding_lookup_sparse)
   * [`erosion2d`](../../api_docs/python/nn.md#erosion2d)
   * [`fixed_unigram_candidate_sampler`](../../api_docs/python/nn.md#fixed_unigram_candidate_sampler)
+  * [`fractional_avg_pool`](../../api_docs/python/nn.md#fractional_avg_pool)
+  * [`fractional_max_pool`](../../api_docs/python/nn.md#fractional_max_pool)
   * [`in_top_k`](../../api_docs/python/nn.md#in_top_k)
   * [`l2_loss`](../../api_docs/python/nn.md#l2_loss)
   * [`l2_normalize`](../../api_docs/python/nn.md#l2_normalize)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 5bf8b5e6c79..b87f176e2bc 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -828,6 +828,151 @@ Performs 3D max pooling on the input.
   A `Tensor`. Has the same type as `input`. The max pooled output tensor.
 
 
+- - -
+
+### `tf.nn.fractional_avg_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_avg_pool}
+
+Performs fractional average pooling on the input.
+
+Fractional average pooling is similar to Fractional max pooling in the pooling
+region generation step. The only difference is that after pooling regions are
+generated, a mean operation is performed instead of a max operation in each
+pooling region.
+
+##### Args:
+
+
+*  <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+    4-D with shape `[batch, height, width, channels]`.
+*  <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
+    Pooling ratio for each dimension of `value`, currently only
+    supports row and col dimension and should be >= 1.0. For example, a valid
+    pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
+    must be 1.0 because we don't allow pooling on batch and channels
+    dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
+    respectively.
+*  <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, generates the pooling sequence in a
+    pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
+    Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
+    difference between pseudorandom and random.
+*  <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, it means when pooling, the values at the boundary
+    of adjacent pooling cells are used by both cells. For example:
+
+    `index  0  1  2  3  4`
+
+    `value  20 5  16 3  7`
+
+    If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
+    The result would be [41/3, 26/3] for fractional avg pooling.
+
+*  <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, a fixed pooling region will be used when
+    iterating over a FractionalAvgPool node in the computation graph. Mainly used
+    in unit test to make FractionalAvgPool deterministic.
+*  <b>`seed`</b>: An optional `int`. Defaults to `0`.
+    If either seed or seed2 are set to be non-zero, the random number
+    generator is seeded by the given seed.  Otherwise, it is seeded by a
+    random seed.
+*  <b>`seed2`</b>: An optional `int`. Defaults to `0`.
+    An second seed to avoid seed collision.
+*  <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
+
+*  <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional avg pooling.
+*  <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
+*  <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
+
+
+- - -
+
+### `tf.nn.fractional_max_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_max_pool}
+
+Performs fractional max pooling on the input.
+
+Fractional max pooling is slightly different than regular max pooling.  In
+regular max pooling, you downsize an input set by taking the maximum value of
+smaller N x N subsections of the set (often 2x2), and try to reduce the set by
+a factor of N, where N is an integer.  Fractional max pooling, as you might
+expect from the word "fractional", means that the overall reduction ratio N
+does not have to be an integer.
+
+The sizes of the pooling regions are generated randomly but are fairly uniform.
+For example, let's look at the height dimension, and the constraints on the
+list of rows that will be pool boundaries.
+
+First we define the following:
+
+1.  input_row_length : the number of rows from the input set
+2.  output_row_length : which will be smaller than the input
+3.  alpha = input_row_length / output_row_length : our reduction ratio
+4.  K = floor(alpha)
+5.  row_pooling_sequence : this is the result list of pool boundary rows
+
+Then, row_pooling_sequence should satisfy:
+
+1.  a[0] = 0 : the first value of the sequence is 0
+2.  a[end] = input_row_length : the last value of the sequence is the size
+3.  K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
+4.  length(row_pooling_sequence) = output_row_length+1
+
+For more details on fractional max pooling, see this paper:
+[Benjamin Graham, Fractional Max-Pooling]
+(http://arxiv.org/abs/1412.6071)
+
+##### Args:
+
+
+*  <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
+    4-D with shape `[batch, height, width, channels]`.
+*  <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
+    Pooling ratio for each dimension of `value`, currently only
+    supports row and col dimension and should be >= 1.0. For example, a valid
+    pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
+    must be 1.0 because we don't allow pooling on batch and channels
+    dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
+    respectively.
+*  <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, generates the pooling sequence in a
+    pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
+    Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
+    difference between pseudorandom and random.
+*  <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, it means when pooling, the values at the boundary
+    of adjacent pooling cells are used by both cells. For example:
+
+    `index  0  1  2  3  4`
+
+    `value  20 5  16 3  7`
+
+    If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
+    The result would be [20, 16] for fractional max pooling.
+
+*  <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
+    When set to True, a fixed pooling region will be used when
+    iterating over a FractionalMaxPool node in the computation graph. Mainly used
+    in unit test to make FractionalMaxPool deterministic.
+*  <b>`seed`</b>: An optional `int`. Defaults to `0`.
+    If either seed or seed2 are set to be non-zero, the random number
+    generator is seeded by the given seed.  Otherwise, it is seeded by a
+    random seed.
+*  <b>`seed2`</b>: An optional `int`. Defaults to `0`.
+    An second seed to avoid seed collision.
+*  <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
+
+*  <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional max pooling.
+*  <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
+*  <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
+
+
 
 ## Morphological filtering
 
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index effe925df06..f1e5b042ef7 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -36,6 +36,8 @@ py_tests(
         "determinant_op_test.py",
         "edit_distance_op_test.py",
         "fifo_queue_test.py",
+        "fractional_avg_pool_op_test.py",
+        "fractional_max_pool_op_test.py",
         "identity_op_py_test.py",
         "in_topk_op_test.py",
         "io_ops_test.py",
diff --git a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
new file mode 100644
index 00000000000..dafecf27288
--- /dev/null
+++ b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
@@ -0,0 +1,521 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fractional average pool operation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import gen_nn_ops
+
+
+class FractionalAvgTest(tf.test.TestCase):
+
+  # Random number generate with seed.
+  _PRNG = np.random.RandomState(341261000)
+  _SEED = 341261001
+  _SEED2 = 341261002
+
+  def _AvgPoolAlongRows(self, input_matrix, row_seq, overlapping):
+    """Perform average pool along row of a 2-D matrix based on row_seq.
+
+    Args:
+      input_matrix: A 2-D matrix.
+      row_seq: Cumulative pooling sequence along row.
+      overlapping: Whether or not use overlapping when pooling.
+
+    Returns:
+      A 2-D matrix, with
+        * num_rows = len(row_seq)-1
+        * num_cols = input_matrix.num_cols.
+    """
+    output_image = np.zeros(input_matrix.shape[1])
+    row_max = row_seq[-1]
+    for i in range(row_seq.shape[0] - 1):
+      row_start = row_seq[i]
+      row_end = row_seq[i + 1] + 1 if overlapping else row_seq[i + 1]
+      row_end = min(row_end, row_max)
+      output_image = np.vstack((output_image,
+                                np.mean(input_matrix[row_start:row_end, :],
+                                        axis=0)))  # axis 0 is along row
+    # remove the sentinel row
+    return output_image[1:, :]
+
+  def _AvgPoolAlongCols(self, input_matrix, col_seq, overlapping):
+    """Perform average pool along column of a 2-D matrix based on col_seq.
+
+    Args:
+      input_matrix: A 2-D matrix.
+      col_seq: Cumulative pooling sequence along column.
+      overlapping: Whether or not use overlapping when pooling.
+
+    Returns:
+      A 2-D matrix, with
+        * num_rows = input_matrix.num_rows
+        * num_cols = len(col_seq)-1.
+    """
+    input_matrix = input_matrix.transpose()
+    output_matrix = self._AvgPoolAlongRows(input_matrix, col_seq, overlapping)
+    return output_matrix.transpose()
+
+  def _GetExpectedFractionalAvgPoolResult(self, input_tensor, row_seq, col_seq,
+                                          overlapping):
+    """Get expected fractional average pooling result.
+
+    row_seq and col_seq together defines the fractional pooling region.
+
+    Args:
+      input_tensor: Original input tensor, assuming it is a 4-D tensor, with
+        dimension as [batch, height/row, width/column, channels/depth].
+      row_seq: Cumulative pooling sequence along row.
+      col_seq: Cumulative pooling sequence along column.
+      overlapping: Use overlapping when doing pooling.
+
+    Returns:
+      A 4-D tensor that is the result of average pooling on input_tensor based
+        on pooling region defined by row_seq and col_seq, conditioned on whether
+        or not overlapping is used.
+    """
+    input_shape = input_tensor.shape
+    output_shape = (input_shape[0], len(row_seq) - 1, len(col_seq) - 1,
+                    input_shape[3])
+    output_tensor = np.zeros(shape=output_shape, dtype=input_tensor.dtype)
+    for batch in range(input_shape[0]):
+      for channel in range(input_shape[3]):
+        two_dim_slice = input_tensor[batch, :, :, channel]
+        tmp = self._AvgPoolAlongRows(two_dim_slice, row_seq, overlapping)
+        output_tensor[batch, :, :, channel] = self._AvgPoolAlongCols(
+            tmp, col_seq, overlapping)
+
+    return output_tensor
+
+  def _ValidateFractionalAvgPoolResult(self, input_tensor, pooling_ratio,
+                                       pseudo_random, overlapping):
+    """Validate FractionalAvgPool's result against expected.
+
+    Expected result is computed given input_tensor, and pooling region defined
+    by row_seq and col_seq.
+
+    Args:
+      input_tensor: A tensor or numpy ndarray.
+      pooling_ratio: A list or tuple of length 4, first and last element be 1.
+      pseudo_random: Use pseudo random method to generate pooling sequence.
+      overlapping: Use overlapping when pooling.
+
+    Returns:
+      None
+    """
+    with self.test_session() as sess:
+      p, r, c = tf.nn.fractional_avg_pool(input_tensor,
+                                          pooling_ratio,
+                                          pseudo_random,
+                                          overlapping,
+                                          deterministic=True,
+                                          seed=self._SEED,
+                                          seed2=self._SEED2)
+      actual, row_seq, col_seq = sess.run([p, r, c])
+      expected = self._GetExpectedFractionalAvgPoolResult(input_tensor, row_seq,
+                                                          col_seq, overlapping)
+      self.assertShapeEqual(expected, p)
+      self.assertAllClose(expected, actual)
+
+  def _testVisually(self):
+    """Manual test by printing out intermediate result of a small random tensor.
+
+    Since _GetExpectedFractionalAvgPoolResult is 'automated', it feels safer to
+    have a test case that you can see what's happening.
+    This test will generate a small, random, int 2D matrix, and feed it to
+    FractionalAvgPool and _GetExpectedFractionalAvgPoolResult.
+    """
+    num_rows = 6
+    num_cols = 6
+    tensor_shape = (1, num_rows, num_cols, 1)
+    pseudo_random = False
+    for overlapping in True, False:
+      print("-" * 70)
+      print("Testing FractionalAvgPool with overlapping = {}".format(
+          overlapping))
+      rand_mat = self._PRNG.randint(10, size=tensor_shape)
+      pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
+      with self.test_session() as sess:
+        p, r, c = tf.nn.fractional_avg_pool(
+            rand_mat.astype(np.float32),
+            pooling_ratio,
+            pseudo_random,
+            overlapping,
+            deterministic=True,
+            seed=self._SEED,
+            seed2=self._SEED2)
+        tensor_output, row_seq, col_seq = sess.run([p, r, c])
+        expected_result = self._GetExpectedFractionalAvgPoolResult(
+            rand_mat.astype(np.float32), row_seq, col_seq, overlapping)
+        print("row sequence:")
+        print(row_seq)
+        print("column sequence:")
+        print(col_seq)
+
+        print("Input:")
+        # Print input with pooling region marked.
+        for i in range(num_rows):
+          row_to_print = []
+          for j in range(num_cols):
+            if j in col_seq:
+              row_to_print.append("|")
+            row_to_print.append(str(rand_mat[0, i, j, 0]))
+          row_to_print.append("|")
+          if i in row_seq:
+            print("-" * 2 * len(row_to_print))
+          print(" ".join(row_to_print))
+        print("-" * 2 * len(row_to_print))
+
+        print("Output from FractionalAvgPool:")
+        print(tensor_output[0, :, :, 0])
+        print("Expected result:")
+        print(expected_result[0, :, :, 0])
+
+  def testAllInputOptions(self):
+    """Try all possible input options for fractional_avg_pool.
+    """
+    num_batches = 5
+    num_channels = 3
+    num_rows = 20
+    num_cols = 30
+    for pseudo_random in True, False:
+      for overlapping in True, False:
+        tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+        # random tensor with value in [-500.0, 500.0)
+        rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+        self._ValidateFractionalAvgPoolResult(
+            rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
+            overlapping)
+
+  def testIntegerTensorInput(self):
+    """Test FractionalAvgPool works fine when input tensor is integer type.
+
+    I would have used _ValidateFractionalAvgPoolResult function to automate this
+    process, however, there's rounding issue. It is caused by numpy.mean cast
+    integer input to numpy.float64 for intermediate use. While for
+    fractional_avg_pool, the mean operation is integer division (trucated).  So,
+    for this test case, I will hard code a simple matrix.
+    """
+    pseudo_random = True
+    overlapping = True
+    tensor_shape = (1, 6, 6, 1)
+    # pyformat: disable
+    mat = np.array([
+        [2, 6, 4, 1, 3, 6],
+        [8, 9, 1, 6, 6, 8],
+        [3, 9, 8, 2, 5, 6],
+        [2, 7, 9, 5, 4, 5],
+        [8, 5, 0, 5, 7, 4],
+        [4, 4, 5, 9, 7, 2]
+    ])
+    # pyformat: enable
+    with self.test_session() as sess:
+      # Since deterministic = True, seed and seed2 are fixed. Therefore r, and c
+      # are the same each time. We can have an expected result precomputed.
+      # r = [0, 2, 4, 6]
+      # c = [0, 1, 3, 4, 6]
+
+      # pyformat: disable
+      expected = np.array([
+          [6, 5, 3, 5],
+          [5, 5, 4, 5],
+          [5, 4, 7, 5]
+      ]).reshape((1, 3, 4, 1))
+      # pyformat: enable
+      p, unused_r, unused_c = tf.nn.fractional_avg_pool(
+          mat.reshape(tensor_shape), [1, math.sqrt(3), math.sqrt(2), 1],
+          pseudo_random,
+          overlapping,
+          deterministic=True,
+          seed=self._SEED,
+          seed2=self._SEED2)
+      actual = sess.run(p)
+      self.assertShapeEqual(expected, p)
+      self.assertAllClose(expected, actual)
+
+  def testDifferentTensorShapes(self):
+    """Test different shapes of input tensor.
+
+    Mainly test different combinations of num_rows and num_cols.
+    """
+    pseudo_random = True
+    overlapping = True
+    for num_batches in [1, 3]:
+      for num_channels in [1, 3]:
+        for num_rows in [10, 20, 50]:
+          for num_cols in [10, 20, 50]:
+            tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+            # random tensor with value in [-500.0, 500.0)
+            rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+            self._ValidateFractionalAvgPoolResult(
+                rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
+                overlapping)
+
+  def testLargePoolingRatio(self):
+    """Test when pooling ratio is not within [1, 2).
+    """
+    pseudo_random = True
+    overlapping = True
+    num_batches = 3
+    num_channels = 3
+    num_rows = 30
+    num_cols = 50
+    tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+    for row_ratio in [math.sqrt(11), math.sqrt(37)]:
+      for col_ratio in [math.sqrt(11), math.sqrt(27)]:
+        # random tensor with value in [-500.0, 500.0)
+        rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+        self._ValidateFractionalAvgPoolResult(rand_mat,
+                                              [1, row_ratio, col_ratio, 1],
+                                              pseudo_random, overlapping)
+
+  def testDivisiblePoolingRatio(self):
+    """Test when num of rows/cols can evenly divide pooling ratio.
+
+    This is a case regular average pooling can handle. Should be handled by
+    fractional pooling as well.
+    """
+    pseudo_random = True
+    overlapping = True
+    num_batches = 3
+    num_channels = 3
+    num_rows = 30
+    num_cols = 50
+    tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+    # random tensor with value in [-500.0, 500.0)
+    rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+    self._ValidateFractionalAvgPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
+                                          overlapping)
+
+
+class FractionalAvgPoolGradTest(tf.test.TestCase):
+  """Tests for FractionalAvgPoolGrad.
+
+  Two types of tests for FractionalAvgPoolGrad.
+  1) Test fractional_avg_pool_grad() directly.
+    This type of test relies on gen_nn_ops._avg_pool_grad() returns the
+  correct result. For example:
+    * input_tensor_shape = (1, 10, 10, 1)
+    * window_size = (1, 2, 2, 1)
+    * stride_size = (1, 2, 2, 1)
+    * padding: not really important, since 10/2 is divisible
+  avg pooling should generate the same result as fractional avg pooling with:
+    * row_sequence = [0, 2, 4, 6, 8, 10]
+    * col_sequence = [0, 2, 4, 6, 8, 10]
+    * overlapping = False
+  This also means their gradients in such case will be the same.
+
+  Similarly, when
+    * input_tensor_shape = (1, 7, 7, 1)
+    * window_size = (1, 3, 3, 1)
+    * stride_size = (1, 2, 2, 1)
+    * padding: not important
+  avg pooling should generate the same result as fractional avg pooling with:
+    * row_sequence = [0, 2, 4, 7]
+    * col_sequence = [0, 2, 4, 7]
+    * overlapping = True
+  2) Test through compute_gradient_error()
+  """
+  _PRNG = np.random.RandomState(341261004)
+  _SEED = 341261005
+  _SEED2 = 341261006
+
+  def _GenerateRandomInputTensor(self, shape):
+    num_elements = 1
+    for dim_size in shape:
+      num_elements *= dim_size
+    x = self._PRNG.rand(num_elements) * 1000
+    return x.reshape(shape)
+
+  def testDirectNotUseOverlapping(self):
+    for num_batches in [1, 3]:
+      for row_window_size in [2, 5]:
+        for col_window_size in [2, 4]:
+          num_rows = row_window_size * 5
+          num_cols = col_window_size * 7
+          for num_channels in [1, 2]:
+            input_shape = (num_batches, num_rows, num_cols, num_channels)
+            with self.test_session() as _:
+              input_tensor = tf.constant(self._GenerateRandomInputTensor(
+                  input_shape).astype(np.float32))
+              window_size = [1, row_window_size, col_window_size, 1]
+              stride_size = [1, row_window_size, col_window_size, 1]
+              padding = "VALID"
+              output_tensor = tf.nn.avg_pool(input_tensor, window_size,
+                                             stride_size, padding)
+              output_data = output_tensor.eval()
+              num_elements = 1
+              for dim_size in output_data.shape:
+                num_elements *= dim_size
+              output_backprop = (self._PRNG.rand(num_elements) *
+                                 1000).reshape(output_data.shape)
+              input_backprop_tensor = gen_nn_ops._avg_pool_grad(
+                  input_tensor.get_shape(), output_backprop, window_size,
+                  stride_size, padding)
+              input_backprop = input_backprop_tensor.eval()
+              row_seq = list(range(0, num_rows + 1, row_window_size))
+              col_seq = list(range(0, num_cols + 1, col_window_size))
+              fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad(
+                  input_tensor.get_shape(),
+                  output_backprop,
+                  row_seq,
+                  col_seq,
+                  overlapping=False)
+              fap_input_backprop = fap_input_backprop_tensor.eval()
+              self.assertShapeEqual(input_backprop, fap_input_backprop_tensor)
+              self.assertAllClose(input_backprop, fap_input_backprop)
+
+  def testDirectUseOverlapping(self):
+    for num_batches in [1, 3]:
+      for row_window_size in [2, 5]:
+        for col_window_size in [2, 4]:
+          num_rows = (row_window_size - 1) * 5 + 1
+          num_cols = (col_window_size - 1) * 7 + 1
+          for num_channels in [1, 2]:
+            input_shape = (num_batches, num_rows, num_cols, num_channels)
+            with self.test_session() as _:
+              input_tensor = tf.constant(self._GenerateRandomInputTensor(
+                  input_shape).astype(np.float32))
+              window_size = [1, row_window_size, col_window_size, 1]
+              stride_size = [1, row_window_size - 1, col_window_size - 1, 1]
+              padding = "VALID"
+              output_tensor = tf.nn.avg_pool(input_tensor, window_size,
+                                             stride_size, padding)
+              output_data = output_tensor.eval()
+              num_elements = 1
+              for dim_size in output_data.shape:
+                num_elements *= dim_size
+              output_backprop = (self._PRNG.rand(num_elements) *
+                                 1000).reshape(output_data.shape)
+              input_backprop_tensor = gen_nn_ops._avg_pool_grad(
+                  input_tensor.get_shape(), output_backprop, window_size,
+                  stride_size, padding)
+              input_backprop = input_backprop_tensor.eval()
+              row_seq = list(range(0, num_rows, row_window_size - 1))
+              col_seq = list(range(0, num_cols, col_window_size - 1))
+              row_seq[-1] += 1
+              col_seq[-1] += 1
+              fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad(
+                  input_tensor.get_shape(),
+                  output_backprop,
+                  row_seq,
+                  col_seq,
+                  overlapping=True)
+              fap_input_backprop = fap_input_backprop_tensor.eval()
+              self.assertShapeEqual(input_backprop, fap_input_backprop_tensor)
+              self.assertAllClose(input_backprop, fap_input_backprop)
+
+  def testAllInputOptionsThroughGradientError(self):
+    input_shape = (1, 7, 13, 1)
+    input_data = self._GenerateRandomInputTensor(input_shape)
+    pooling_ratio = [1, math.sqrt(2), math.sqrt(3), 1]
+
+    for pseudo_random in True, False:
+      for overlapping in True, False:
+        with self.test_session() as _:
+          input_tensor = tf.constant(input_data, shape=input_shape)
+          output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool(
+              input_tensor,
+              pooling_ratio,
+              pseudo_random=pseudo_random,
+              overlapping=overlapping,
+              deterministic=True,
+              seed=self._SEED,
+              seed2=self._SEED2)
+          output_data = output_tensor.eval()
+          output_shape = output_data.shape
+          # error_margin and delta setting is similar to avg_pool_grad.
+          error_margin = 1e-4
+          gradient_error = tf.test.compute_gradient_error(
+              input_tensor,
+              input_shape,
+              output_tensor,
+              output_shape,
+              x_init_value=input_data.reshape(input_shape),
+              delta=1e-2)
+          self.assertLess(gradient_error, error_margin)
+
+  def testDifferentTensorShapesThroughGradientError(self):
+    pseudo_random = True
+    overlapping = True
+    pooling_ratio = [1, math.sqrt(3), math.sqrt(2), 1]
+    for num_batches in [1, 2]:
+      for num_rows in [5, 13]:
+        for num_cols in [5, 11]:
+          for num_channels in [1, 3]:
+            input_shape = (num_batches, num_rows, num_cols, num_channels)
+            input_data = self._GenerateRandomInputTensor(input_shape)
+            with self.test_session() as _:
+              input_tensor = tf.constant(input_data, shape=input_shape)
+              output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool(
+                  input_tensor,
+                  pooling_ratio,
+                  pseudo_random=pseudo_random,
+                  overlapping=overlapping,
+                  deterministic=True,
+                  seed=self._SEED,
+                  seed2=self._SEED2)
+              output_data = output_tensor.eval()
+              output_shape = output_data.shape
+              # error_margin and delta setting is similar to avg_pool_grad.
+              error_margin = 1e-4
+              gradient_error = tf.test.compute_gradient_error(
+                  input_tensor,
+                  input_shape,
+                  output_tensor,
+                  output_shape,
+                  x_init_value=input_data.reshape(input_shape),
+                  delta=1e-2)
+              self.assertLess(gradient_error, error_margin)
+
+  def testLargePoolingRatioThroughGradientError(self):
+    input_shape = (1, 17, 23, 1)
+    input_data = self._GenerateRandomInputTensor(input_shape)
+    pooling_ratio = (1, math.sqrt(13), math.sqrt(7), 1)
+    output_shape = [int(a / b) for a, b in zip(input_shape, pooling_ratio)]
+    overlapping = True
+    pseudo_random = False
+
+    with self.test_session() as _:
+      input_tensor = tf.constant(input_data, shape=input_shape)
+      output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool(
+          input_tensor,
+          pooling_ratio,
+          pseudo_random=pseudo_random,
+          overlapping=overlapping,
+          deterministic=True,
+          seed=self._SEED,
+          seed2=self._SEED2)
+      # error_margin and delta setting is similar to avg_pool_grad.
+      error_margin = 1e-4
+      gradient_error = tf.test.compute_gradient_error(
+          input_tensor,
+          input_shape,
+          output_tensor,
+          output_shape,
+          x_init_value=input_data.reshape(input_shape),
+          delta=1e-2)
+      self.assertLess(gradient_error, error_margin)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
new file mode 100644
index 00000000000..424f4d588ab
--- /dev/null
+++ b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
@@ -0,0 +1,582 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for fractional max pool operation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import gen_nn_ops
+
+
+class FractionalMaxPoolTest(tf.test.TestCase):
+
+  # Random number generate with seed.
+  _PRNG = np.random.RandomState(341261)
+  _SEED = 123456
+  _SEED2 = 654321
+
+  def _MaxPoolAlongRows(self, input_matrix, row_seq, overlapping):
+    """Perform max pool along row of a 2-D matrix based on row_seq.
+
+    Args:
+      input_matrix: A 2-D matrix.
+      row_seq: Cumulative pooling sequence along row.
+      overlapping: Whether or not use overlapping when pooling.
+
+    Returns:
+      A 2-D matrix, with
+        * num_rows = len(row_seq)-1
+        * num_cols = input_matrix.num_cols.
+    """
+    output_image = np.zeros(input_matrix.shape[1])
+    row_max = row_seq[-1]
+    for i in range(row_seq.shape[0] - 1):
+      row_start = row_seq[i]
+      row_end = row_seq[i + 1] + 1 if overlapping else row_seq[i + 1]
+      row_end = min(row_end, row_max)
+      output_image = np.vstack((output_image,
+                                np.amax(input_matrix[row_start:row_end, :],
+                                        axis=0)))  # axis 0 is along row
+    # remove the sentinel row
+    return output_image[1:, :]
+
+  def _MaxPoolAlongCols(self, input_matrix, col_seq, overlapping):
+    """Perform max pool along column of a 2-D matrix based on col_seq.
+
+    Args:
+      input_matrix: A 2-D matrix.
+      col_seq: Cumulative pooling sequence along column.
+      overlapping: Whether or not use overlapping when pooling.
+
+    Returns:
+      A 2-D matrix, with
+        * num_rows = input_matrix.num_rows
+        * num_cols = len(col_seq)-1.
+    """
+    input_matrix = input_matrix.transpose()
+    output_matrix = self._MaxPoolAlongRows(input_matrix, col_seq, overlapping)
+    return output_matrix.transpose()
+
+  def _GetExpectedFractionalMaxPoolResult(self, input_tensor, row_seq, col_seq,
+                                          overlapping):
+    """Get expected fractional max pool result.
+
+    row_seq and col_seq together defines the fractional pooling region.
+
+    Args:
+      input_tensor: Original input tensor, assuming it is a 4-D tensor, with
+        dimension as [batch, height/row, width/column, channels/depth].
+      row_seq: Cumulative pooling sequence along row.
+      col_seq: Cumulative pooling sequence along column.
+      overlapping: Use overlapping when doing pooling.
+
+    Returns:
+      A 4-D tensor that is the result of max pooling on input_tensor based on
+        pooling region defined by row_seq and col_seq, conditioned on whether or
+        not overlapping is used.
+    """
+    input_shape = input_tensor.shape
+    output_shape = (input_shape[0], len(row_seq) - 1, len(col_seq) - 1,
+                    input_shape[3])
+    output_tensor = np.zeros(shape=output_shape, dtype=input_tensor.dtype)
+    for batch in range(input_shape[0]):
+      for channel in range(input_shape[3]):
+        two_dim_slice = input_tensor[batch, :, :, channel]
+        tmp = self._MaxPoolAlongRows(two_dim_slice, row_seq, overlapping)
+        output_tensor[batch, :, :, channel] = self._MaxPoolAlongCols(
+            tmp, col_seq, overlapping)
+
+    return output_tensor
+
+  def _ValidateFractionalMaxPoolResult(self, input_tensor, pooling_ratio,
+                                       pseudo_random, overlapping):
+    """Validate FractionalMaxPool's result against expected.
+
+    Expected result is computed given input_tensor, and pooling region defined
+    by row_seq and col_seq.
+
+    Args:
+      input_tensor: A tensor or numpy ndarray.
+      pooling_ratio: A list or tuple of length 4, first and last element be 1.
+      pseudo_random: Use pseudo random method to generate pooling sequence.
+      overlapping: Use overlapping when pooling.
+
+    Returns:
+      None
+    """
+    with self.test_session() as sess:
+      p, r, c = tf.nn.fractional_max_pool(input_tensor,
+                                          pooling_ratio,
+                                          pseudo_random,
+                                          overlapping,
+                                          deterministic=True,
+                                          seed=self._SEED,
+                                          seed2=self._SEED2)
+      actual, row_seq, col_seq = sess.run([p, r, c])
+      expected = self._GetExpectedFractionalMaxPoolResult(input_tensor, row_seq,
+                                                          col_seq, overlapping)
+      self.assertShapeEqual(expected, p)
+      self.assertAllClose(expected, actual)
+
+  def _testVisually(self):
+    """Manual test by printing out intermediate result of a small random tensor.
+
+    Since _GetExpectedFractionalMaxPoolResult is 'automated', it feel safer to
+    have a test case that you can see what's happening.
+    This test will generate a small, random, int 2D matrix, and feed it to
+    FractinalMaxPool and _GetExpectedFractionalMaxPoolResult.
+    """
+    num_rows = 6
+    num_cols = 6
+    tensor_shape = (1, num_rows, num_cols, 1)
+    pseudo_random = False
+    for overlapping in True, False:
+      print("-" * 70)
+      print("Testing FractionalMaxPool with overlapping = {}".format(
+          overlapping))
+      rand_mat = self._PRNG.randint(10, size=tensor_shape)
+      pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
+      with self.test_session() as sess:
+        p, r, c = tf.nn.fractional_max_pool(rand_mat,
+                                            pooling_ratio,
+                                            pseudo_random,
+                                            overlapping,
+                                            deterministic=True,
+                                            seed=self._SEED,
+                                            seed2=self._SEED2)
+        tensor_output, row_seq, col_seq = sess.run([p, r, c])
+        expected_result = self._GetExpectedFractionalMaxPoolResult(rand_mat,
+                                                                   row_seq,
+                                                                   col_seq,
+                                                                   overlapping)
+        print("row sequence:")
+        print(row_seq)
+        print("column sequence:")
+        print(col_seq)
+
+        print("Input:")
+        # Print input with pooling region marked.
+        for i in range(num_rows):
+          row_to_print = []
+          for j in range(num_cols):
+            if j in col_seq:
+              row_to_print.append("|")
+            row_to_print.append(str(rand_mat[0, i, j, 0]))
+          row_to_print.append("|")
+          if i in row_seq:
+            print("-" * 2 * len(row_to_print))
+          print(" ".join(row_to_print))
+        print("-" * 2 * len(row_to_print))
+
+        print("Output from FractionalMaxPool:")
+        print(tensor_output[0, :, :, 0])
+        print("Expected result:")
+        print(expected_result[0, :, :, 0])
+
+  def testAllInputOptions(self):
+    """Try all possible input options for fractional_max_pool.
+    """
+    num_batches = 5
+    num_channels = 3
+    num_rows = 20
+    num_cols = 30
+    for pseudo_random in True, False:
+      for overlapping in True, False:
+        tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+        # random tensor with value in [-500.0, 500.0)
+        rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+        self._ValidateFractionalMaxPoolResult(
+            rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
+            overlapping)
+
+  def testIntegerTensorInput(self):
+    """Test it works fine when input tensor is integer type.
+    """
+    num_batches = 5
+    num_channels = 3
+    num_rows = 20
+    num_cols = 30
+    pseudo_random = True
+    overlapping = True
+    tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+    rand_mat = self._PRNG.randint(1000, size=tensor_shape)
+    self._ValidateFractionalMaxPoolResult(rand_mat,
+                                          [1, math.sqrt(3), math.sqrt(2), 1],
+                                          pseudo_random, overlapping)
+
+  def testDifferentTensorShapes(self):
+    """Test different shapes of input tensor.
+
+    Mainly test different combinations of num_rows and num_cols.
+    """
+    pseudo_random = True
+    overlapping = True
+    for num_batches in [1, 3]:
+      for num_channels in [1, 3]:
+        for num_rows in [10, 20, 50]:
+          for num_cols in [10, 20, 50]:
+            tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+            # random tensor with value in [-500.0, 500.0)
+            rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+            self._ValidateFractionalMaxPoolResult(
+                rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
+                overlapping)
+
+  def testLargePoolingRatio(self):
+    """Test when pooling ratio is not within [1, 2).
+    """
+    pseudo_random = True
+    overlapping = True
+    num_batches = 3
+    num_channels = 3
+    num_rows = 30
+    num_cols = 50
+    tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+    for row_ratio in [math.sqrt(11), math.sqrt(37)]:
+      for col_ratio in [math.sqrt(11), math.sqrt(27)]:
+        # random tensor with value in [-500.0, 500.0)
+        rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+        self._ValidateFractionalMaxPoolResult(rand_mat,
+                                              [1, row_ratio, col_ratio, 1],
+                                              pseudo_random, overlapping)
+
+  def testDivisiblePoolingRatio(self):
+    """Test when num of rows/cols can evenly divide pooling ratio.
+
+    This is a case regular max pooling can handle. Should be handled by
+    fractional pooling as well.
+    """
+    pseudo_random = True
+    overlapping = True
+    num_batches = 3
+    num_channels = 3
+    num_rows = 30
+    num_cols = 50
+    tensor_shape = (num_batches, num_rows, num_cols, num_channels)
+    # random tensor with value in [-500.0, 500.0)
+    rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
+    self._ValidateFractionalMaxPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
+                                          overlapping)
+
+
+class FractionalMaxPoolGradTest(tf.test.TestCase):
+  """Tests for FractionalMaxPoolGrad.
+
+  Two types of tests for FractionalMaxPoolGrad.
+  1) Test fractional_max_pool_grad() directly.
+    This type of test relies on gen_nn_ops._max_pool_grad() returns the correct
+  result. For example:
+    * input_tensor_shape = (1, 10, 10, 1)
+    * window_size = (1, 2, 2, 1)
+    * stride_size = (1, 2, 2, 1)
+    * padding: not really import, since 10/2 is divisible
+  max pooling should generate the same result as fractional max pooling with:
+    * row_sequence = [0, 2, 4, 6, 8, 10]
+    * col_sequence = [0, 2, 4, 6, 8, 10]
+    * overlapping = False
+  This also means their gradients in such case will be the same.
+
+    Similarly, when
+    * input_tensor_shape = (1, 7, 7, 1)
+    * window_size = (1, 3, 3, 1)
+    * stride_size = (1, 2, 2, 1)
+    * padding: not important
+  max pooling should generate the same result as fractional max pooling with:
+    * row_sequence = [0, 2, 4, 7]
+    * col_sequence = [0, 2, 4, 7]
+    * overlapping = True
+  2) Test through compute_gradient_error()
+  """
+
+  _PRNG = np.random.RandomState(341261)
+  _SEED = 123456
+  _SEED2 = 654321
+
+  def _GenerateUniqueRandomInputTensor(self, shape):
+    """Generate 'unqiue' random input tensor.
+
+    'Unique' means there's no collision values in the tensor, all elements are
+    different. This is done by generating sequence of integers with step of 1
+    and then randomly shuffle these integers.
+
+    Args:
+      shape: Shape of the tensor desired.
+
+    Returns:
+      A numpy ndarray with size = shape and dtype = numpy.float32.
+    """
+    num_elements = 1
+    for size in shape:
+      num_elements *= size
+    x = np.arange(num_elements, dtype=np.float32)
+    self._PRNG.shuffle(x)
+    return x.reshape(shape)
+
+  def testDirectNotUseOverlapping(self):
+    for num_batches in [1, 3]:
+      for row_window_size in [2, 5]:
+        for col_window_size in [2, 4]:
+          num_rows = row_window_size * 5
+          num_cols = col_window_size * 7
+          for num_channels in [1, 2]:
+            input_shape = (num_batches, num_rows, num_cols, num_channels)
+            with self.test_session() as _:
+              input_tensor = tf.constant(self._GenerateUniqueRandomInputTensor(
+                  input_shape))
+              window_size = [1, row_window_size, col_window_size, 1]
+              stride_size = [1, row_window_size, col_window_size, 1]
+              padding = "VALID"
+              output_tensor = tf.nn.max_pool(input_tensor, window_size,
+                                             stride_size, padding)
+              output_data = output_tensor.eval()
+              output_backprop = self._PRNG.randint(100, size=output_data.shape)
+              input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor,
+                                                                output_tensor,
+                                                                output_backprop,
+                                                                window_size,
+                                                                stride_size,
+                                                                padding)
+              input_backprop = input_backprop_tensor.eval()
+              row_seq = list(range(0, num_rows + 1, row_window_size))
+              col_seq = list(range(0, num_cols + 1, col_window_size))
+              fmp_input_backprop_tensor = gen_nn_ops._fractional_max_pool_grad(
+                  input_tensor,
+                  output_tensor,
+                  output_backprop,
+                  row_seq,
+                  col_seq,
+                  overlapping=False)
+              fmp_input_backprop = fmp_input_backprop_tensor.eval()
+              self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
+              self.assertAllClose(input_backprop, fmp_input_backprop)
+
+  def testDirectUseOverlapping(self):
+    for num_batches in [1, 3]:
+      for row_window_size in [2, 5]:
+        for col_window_size in [2, 4]:
+          num_rows = (row_window_size - 1) * 5 + 1
+          num_cols = (col_window_size - 1) * 7 + 1
+          for num_channels in [1, 2]:
+            input_shape = (num_batches, num_rows, num_cols, num_channels)
+            with self.test_session() as _:
+              input_tensor = tf.constant(self._GenerateUniqueRandomInputTensor(
+                  input_shape))
+              window_size = [1, row_window_size, col_window_size, 1]
+              stride_size = [1, row_window_size - 1, col_window_size - 1, 1]
+              padding = "VALID"
+              output_tensor = tf.nn.max_pool(input_tensor, window_size,
+                                             stride_size, padding)
+              output_data = output_tensor.eval()
+              output_backprop = self._PRNG.randint(100, size=output_data.shape)
+              input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor,
+                                                                output_tensor,
+                                                                output_backprop,
+                                                                window_size,
+                                                                stride_size,
+                                                                padding)
+              input_backprop = input_backprop_tensor.eval()
+              row_seq = list(range(0, num_rows, row_window_size - 1))
+              col_seq = list(range(0, num_cols, col_window_size - 1))
+              row_seq[-1] += 1
+              col_seq[-1] += 1
+              fmp_input_backprop_tensor = gen_nn_ops._fractional_max_pool_grad(
+                  input_tensor,
+                  output_tensor,
+                  output_backprop,
+                  row_seq,
+                  col_seq,
+                  overlapping=True)
+              fmp_input_backprop = fmp_input_backprop_tensor.eval()
+              self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
+              self.assertAllClose(input_backprop, fmp_input_backprop)
+
+  def testAllInputOptionsThroughGradientError(self):
+    input_shape = (1, 7, 13, 1)
+    input_data = self._GenerateUniqueRandomInputTensor(input_shape)
+    # Add some randomness to make input_data not so 'integer'
+    input_data += self._PRNG.random_sample(input_shape)
+    pooling_ratio = [1, math.sqrt(2), math.sqrt(3), 1]
+
+    for pseudo_random in True, False:
+      for overlapping in True, False:
+        with self.test_session() as _:
+          input_tensor = tf.constant(input_data, shape=input_shape)
+          output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool(
+              input_tensor,
+              pooling_ratio,
+              pseudo_random=pseudo_random,
+              overlapping=overlapping,
+              deterministic=True,
+              seed=self._SEED,
+              seed2=self._SEED2)
+          output_data = output_tensor.eval()
+          output_shape = output_data.shape
+          # error_margin and delta setting is similar to max_pool_grad.
+          error_margin = 1e-3
+          gradient_error = tf.test.compute_gradient_error(
+              input_tensor,
+              input_shape,
+              output_tensor,
+              output_shape,
+              x_init_value=input_data.reshape(input_shape),
+              delta=1e-2)
+          self.assertLess(gradient_error, error_margin)
+
+  def testDifferentTensorShapesThroughGradientError(self):
+    pseudo_random = True
+    overlapping = True
+    pooling_ratio = [1, math.sqrt(3), math.sqrt(2), 1]
+    for num_batches in [1, 2]:
+      for num_rows in [5, 13]:
+        for num_cols in [5, 11]:
+          for num_channels in [1, 3]:
+            input_shape = (num_batches, num_rows, num_cols, num_channels)
+            input_data = self._GenerateUniqueRandomInputTensor(input_shape)
+            # Add some randomness to make input_data not so 'integer'
+            input_data += self._PRNG.random_sample(input_shape)
+            with self.test_session() as _:
+              input_tensor = tf.constant(input_data, shape=input_shape)
+              output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool(
+                  input_tensor,
+                  pooling_ratio,
+                  pseudo_random=pseudo_random,
+                  overlapping=overlapping,
+                  deterministic=True,
+                  seed=self._SEED,
+                  seed2=self._SEED2)
+              output_data = output_tensor.eval()
+              output_shape = output_data.shape
+              # error_margin and delta setting is similar to max_pool_grad.
+              error_margin = 1e-3
+              gradient_error = tf.test.compute_gradient_error(
+                  input_tensor,
+                  input_shape,
+                  output_tensor,
+                  output_shape,
+                  x_init_value=input_data.reshape(input_shape),
+                  delta=1e-2)
+              self.assertLess(gradient_error, error_margin)
+
+  def testLargePoolingRatioThroughGradientError(self):
+    input_shape = (1, 17, 23, 1)
+    input_data = self._GenerateUniqueRandomInputTensor(input_shape)
+    # Add some randomness to make input_data not so 'integer'
+    input_data += self._PRNG.random_sample(input_shape)
+    pooling_ratio = (1, math.sqrt(13), math.sqrt(7), 1)
+    output_shape = [int(a / b) for a, b in zip(input_shape, pooling_ratio)]
+    overlapping = True
+    pseudo_random = False
+
+    with self.test_session() as _:
+      input_tensor = tf.constant(input_data, shape=input_shape)
+      output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool(
+          input_tensor,
+          pooling_ratio,
+          pseudo_random=pseudo_random,
+          overlapping=overlapping,
+          deterministic=True,
+          seed=self._SEED,
+          seed2=self._SEED2)
+      # error_margin and delta setting is similar to max_pool_grad.
+      error_margin = 1e-3
+      gradient_error = tf.test.compute_gradient_error(
+          input_tensor,
+          input_shape,
+          output_tensor,
+          output_shape,
+          x_init_value=input_data.reshape(input_shape),
+          delta=1e-2)
+      self.assertLess(gradient_error, error_margin)
+
+  def testWhenRepeatedMaxValueInPoolingRegion(self):
+    """Test when there's repeating value in pooling region.
+
+    There's no formal definition for what the gradient should be when there're
+    multiple max value within a pooling cell. Such as
+        | 1 5 |
+        | 5 3 |
+    The expected result depends heavily on implementation, if someone swap the
+    order of a nested for loop when walking through the tensor, result would be
+    very different.
+
+    The goal of this test is to alert when someone else change the
+    implementation. Current implementation scans row-by-row.
+    """
+    input_data = [5.0, 4.0, 6.0, 7.0,
+                  3.0, 5.0, 9.0, 6.0,
+                  8.0, 8.0, 9.0, 5.0,
+                  7.0, 4.0, 0.0, 0.0]  # pyformat: disable
+    input_size = [1, 4, 4, 1]
+    output_backprop = [12.0, 15.0,
+                       17.0, -5.0,
+                       6.0, 21.0]  # pyformat: disable
+    row_seq = [0, 1, 3, 4]
+    col_seq = [0, 2, 4]
+    output_data_not_overlapping = [5.0, 7.0,
+                                   8.0, 9.0,
+                                   7.0, 0.0]  # pyformat: disable
+    output_data_overlapping = [9.0, 9.0,
+                               9.0, 9.0,
+                               7.0, 0.0]  # pyformat: disable
+    output_size = [1, 3, 2, 1]
+    expected_input_backprop_not_overlapping = np.reshape(
+        [12.0, 0.0, 0.0, 15.0,
+         0.0, 0.0, -5.0, 0.0,
+         17.0, 0.0, 0.0, 0.0,
+         6.0, 0.0, 21.0, 0.0],
+        input_size)  # pyformat: disable
+    expected_input_backprop_overlapping = np.reshape(
+        [0.0, 0.0, 0.0, 0.0,
+         0.0, 0.0, 39.0, 0.0,
+         0.0, 0.0, 0.0, 0.0,
+         6.0, 0.0, 21.0, 0.0],
+        input_size)  # pyformat: disable
+    with self.test_session() as _:
+      # Test when overlapping is False
+      input_tensor = tf.constant(input_data, shape=input_size)
+      output_tensor = tf.constant(output_data_not_overlapping,
+                                  shape=output_size)
+      grad = tf.constant(output_backprop, shape=output_size)
+      r = gen_nn_ops._fractional_max_pool_grad(
+          input_tensor,
+          output_tensor,
+          grad,
+          row_seq,
+          col_seq,
+          overlapping=False)
+      input_backprop_not_overlapping = r.eval()
+      self.assertShapeEqual(
+          np.reshape(expected_input_backprop_not_overlapping, input_size), r)
+      self.assertAllClose(expected_input_backprop_not_overlapping,
+                          input_backprop_not_overlapping)
+      # Test when overlapping is True
+      output_tensor = tf.constant(output_data_overlapping, shape=output_size)
+      r = gen_nn_ops._fractional_max_pool_grad(
+          input_tensor, output_tensor, grad, row_seq, col_seq, overlapping=True)
+      input_backprop_overlapping = r.eval()
+      self.assertShapeEqual(
+          np.reshape(expected_input_backprop_overlapping, input_size), r)
+      self.assertAllClose(expected_input_backprop_overlapping,
+                          input_backprop_overlapping)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 3a94b9f4e39..e243720cabf 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -181,6 +181,8 @@ AvgPool
 MaxPool
 Softmax
 LogSoftmax
+FractionalAvgPoolGrad
+FractionalMaxPoolGrad
 
 # parsing_ops
 ParseExample
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 941bbfd271b..3c3c0cca41a 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -132,6 +132,8 @@ to the `Convolution` section for details about the padding calculation.
 @@max_pool_with_argmax
 @@avg_pool3d
 @@max_pool3d
+@@fractional_avg_pool
+@@fractional_max_pool
 
 ## Morphological filtering
 
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index a396f06ebad..561ba6d5bb7 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -361,6 +361,53 @@ def _MaxPoolGrad(op, grad):
                                    data_format=op.get_attr("data_format"))
 
 
+@ops.RegisterGradient("FractionalMaxPool")
+def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
+  """Returns gradient for FractionalMaxPool.
+
+  Since FractionalMaxPool has three outputs, there are three gradients passed in
+  for each of the outputs. Only the first one is useful, the other two gradients
+  are empty.
+
+  Args:
+    op: The FractionalMaxPoolOp.
+    grad_0: Gradient with respect to op.outputs[0]
+    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
+    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.
+
+  Returns:
+    Input backprop for FractionalMaxPool op.
+  """
+  # pylint: disable=protected-access
+  return gen_nn_ops._fractional_max_pool_grad(op.inputs[0], op.outputs[0],
+                                              grad_0, op.outputs[1],
+                                              op.outputs[2],
+                                              op.get_attr("overlapping"))
+
+
+@ops.RegisterGradient("FractionalAvgPool")
+def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
+  """Returns gradient for FractionalAvgPool.
+
+  Since FractionalAvgPool has three outputs, there are three gradients passed in
+  for each of the outputs. Only the first one is useful, the other two gradients
+  are empty.
+
+  Args:
+    op: The FractionalAvgPoolOp.
+    grad_0: Gradient with respect to op.outputs[0]
+    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
+    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.
+
+  Returns:
+    Input backprop for FractionalAvgPool op.
+  """
+  # pylint: disable=protected-access
+  return gen_nn_ops._fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
+                                              op.outputs[1], op.outputs[2],
+                                              op.get_attr("overlapping"))
+
+
 @ops.RegisterGradient("BatchNormWithGlobalNormalization")
 def _BatchNormWithGlobalNormalizationGrad(op, grad):
   """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index f3805c3f2d7..f5f5aedc01a 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -941,6 +941,39 @@ def _AvgPoolGradShape(op):
     return [tensor_shape.unknown_shape(ndims=4)]
 
 
+@ops.RegisterShape("FractionalMaxPool")
+@ops.RegisterShape("FractionalAvgPool")
+def _fractional_pool_shape(op):
+  input_dims = op.inputs[0].get_shape().with_rank(4).as_list()
+  pooling_ratio = op.get_attr("pooling_ratio")
+  output_dims = np.divide(input_dims, pooling_ratio).astype(int)
+  return [
+      # output.
+      tensor_shape.TensorShape(output_dims),
+      # row_pooling_sequence.
+      tensor_shape.TensorShape([output_dims[1]]),
+      # col_pooling_sequence.
+      tensor_shape.TensorShape([output_dims[2]])
+  ]
+
+
+@ops.RegisterShape("FractionalMaxPoolGrad")
+def _fractional_max_pool_grad_shape(op):
+  """Shape function for the FractionalMaxPoolGrad op."""
+  orig_input_shape = op.inputs[0].get_shape().with_rank(4)
+  return [orig_input_shape]
+
+
+@ops.RegisterShape("FractionalAvgPoolGrad")
+def _fractional_avg_pool_grad_shape(op):
+  """Shape function for the FractionalAvgPoolGrad op."""
+  orig_input_shape = tensor_util.constant_value(op.inputs[0])
+  if orig_input_shape is not None:
+    return [tensor_shape.TensorShape(orig_input_shape.tolist())]
+  else:
+    return [tensor_shape.unknown_shape(ndims=4)]
+
+
 @ops.RegisterShape("Conv2DBackpropFilter")
 def _Conv2DBackpropFilterShape(op):
   """Shape function for the Conv2DBackpropFilter op."""