Adds fractional_max_pool and fractional_avg_pool ops. Fixes #2953.

Change: 131754627
A. Unique TensorFlower 2016-08-30 13:23:06 -08:00 committed by TensorFlower Gardener
parent 79d8721bf2
commit 8b667b7d4b
17 changed files with 2627 additions and 0 deletions


@@ -1448,6 +1448,9 @@ tf_kernel_library(
srcs = [
"avgpooling_op.cc",
"cudnn_pooling_gpu.cc",
"fractional_avg_pool_op.cc",
"fractional_max_pool_op.cc",
"fractional_pool_common.cc",
"maxpooling_op.cc",
"pooling_ops_3d.cc",
"pooling_ops_common.cc",
@@ -1455,6 +1458,7 @@ tf_kernel_library(
hdrs = [
"avgpooling_op.h",
"cudnn_pooling_gpu.h",
"fractional_pool_common.h",
"maxpooling_op.h",
"pooling_ops_common.h",
],


@@ -0,0 +1,354 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include <algorithm>
#include <cmath>
#include <random>
#include <vector>
#include "tensorflow/core/kernels/fractional_pool_common.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/guarded_philox_random.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename T>
class FractionalAvgPoolOp : public OpKernel {
public:
explicit FractionalAvgPoolOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
OP_REQUIRES(context, pooling_ratio_.size() == 4,
errors::InvalidArgument(
"pooling_ratio field must specify 4 dimensions"));
OP_REQUIRES(
context, pooling_ratio_[0] == 1 && pooling_ratio_[3] == 1,
errors::Unimplemented("Fractional average pooling is not yet "
"supported on the batch or channel dimension."));
OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
pooling_region_generated_ = false;
// Initialize philox random generator.
OP_REQUIRES_OK(context, generator_.Init(context));
}
void Compute(OpKernelContext* context) override {
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
ConstEigenMatrixMap;
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
EigenMatrixMap;
constexpr int tensor_in_and_out_dims = 4;
const Tensor& tensor_in = context->input(0);
OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
errors::InvalidArgument("tensor_in must be 4-dimensional"));
// Reset the cached sizes so repeated Compute() calls do not accumulate
// stale dimensions.
input_size_.clear();
output_size_.clear();
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
input_size_.push_back(tensor_in.dim_size(i));
}
// Output size.
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
output_size_.push_back(
static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
DCHECK_GT(output_size_[i], 0);
}
// Generate pooling sequence.
std::vector<int64> row_cum_seq;
std::vector<int64> col_cum_seq;
if (deterministic_) {
// Grab the lock before reading pooling_region_generated_ or the cached
// sequences, so concurrent Compute() calls stay consistent.
mutex_lock lock(mu_);
if (pooling_region_generated_) {
row_cum_seq = row_cum_seq_;
col_cum_seq = col_cum_seq_;
} else {
row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
&generator_, pseudo_random_);
col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
&generator_, pseudo_random_);
row_cum_seq_ = row_cum_seq;
col_cum_seq_ = col_cum_seq;
pooling_region_generated_ = true;
}
} else {
row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
&generator_, pseudo_random_);
col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
&generator_, pseudo_random_);
}
// Prepare output.
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
0, TensorShape({output_size_[0], output_size_[1],
output_size_[2], output_size_[3]}),
&output_tensor));
Tensor* output_row_seq_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
1, TensorShape({static_cast<int64>(row_cum_seq.size())}),
&output_row_seq_tensor));
Tensor* output_col_seq_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
2, TensorShape({static_cast<int64>(col_cum_seq.size())}),
&output_col_seq_tensor));
ConstEigenMatrixMap in_mat(
tensor_in.flat<T>().data(), input_size_[3],
input_size_[2] * input_size_[1] * input_size_[0]);
EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
output_size_[2] * output_size_[1] * output_size_[0]);
// out_count corresponds to the number of elements in each pooling cell.
Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols());
// Initializes the output tensor and out_count with 0.
out_mat.setZero();
out_count.setZero();
auto output_row_seq_flat = output_row_seq_tensor->flat<int64>();
auto output_col_seq_flat = output_col_seq_tensor->flat<int64>();
// Set output tensors.
for (int i = 0; i < row_cum_seq.size(); ++i) {
output_row_seq_flat(i) = row_cum_seq[i];
}
for (int i = 0; i < col_cum_seq.size(); ++i) {
output_col_seq_flat(i) = col_cum_seq[i];
}
// For both input and output,
// 0: batch
// 1: row / row
// 2: col / col
// 3: depth / channel
const int64 row_max = input_size_[1] - 1;
const int64 col_max = input_size_[2] - 1;
for (int64 b = 0; b < input_size_[0]; ++b) {
// row sequence.
for (int64 hs = 0; hs < row_cum_seq.size() - 1; ++hs) {
// row start and end.
const int64 row_start = row_cum_seq[hs];
int64 row_end =
overlapping_ ? row_cum_seq[hs + 1] : row_cum_seq[hs + 1] - 1;
row_end = std::min(row_end, row_max);
// col sequence.
for (int64 ws = 0; ws < col_cum_seq.size() - 1; ++ws) {
const int64 out_offset =
(b * output_size_[1] + hs) * output_size_[2] + ws;
// col start and end.
const int64 col_start = col_cum_seq[ws];
int64 col_end =
overlapping_ ? col_cum_seq[ws + 1] : col_cum_seq[ws + 1] - 1;
col_end = std::min(col_end, col_max);
for (int64 h = row_start; h <= row_end; ++h) {
for (int64 w = col_start; w <= col_end; ++w) {
const int64 in_offset =
(b * input_size_[1] + h) * input_size_[2] + w;
out_mat.col(out_offset) += in_mat.col(in_offset);
out_count(out_offset)++;
}
}
}
}
}
DCHECK_GT(out_count.minCoeff(), 0);
out_mat.array().rowwise() /= out_count.transpose().array();
}
private:
bool deterministic_;
// meaningful only when deterministic_ is true.
mutex mu_;
std::vector<int64> row_cum_seq_;
std::vector<int64> col_cum_seq_;
bool pooling_region_generated_;
std::vector<int32> input_size_;
std::vector<int32> output_size_;
std::vector<float> pooling_ratio_;
bool pseudo_random_;
bool overlapping_;
GuardedPhiloxRandom generator_;
};
#define REGISTER_FRACTIONALAVGPOOL(type) \
REGISTER_KERNEL_BUILDER( \
Name("FractionalAvgPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
FractionalAvgPoolOp<type>)
REGISTER_FRACTIONALAVGPOOL(int32);
REGISTER_FRACTIONALAVGPOOL(int64);
REGISTER_FRACTIONALAVGPOOL(float);
REGISTER_FRACTIONALAVGPOOL(double);
#undef REGISTER_FRACTIONALAVGPOOL
template <class T>
class FractionalAvgPoolGradOp : public OpKernel {
public:
explicit FractionalAvgPoolGradOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
}
void Compute(OpKernelContext* context) override {
// Here's the basic idea:
// The batch and depth dimensions are independent of the row and col
// dimensions. Because FractionalAvgPool currently only supports pooling
// along the row and col dimensions, we can think of this 4D tensor
// backpropagation as an operation over a series of 2D planes.
//
// For each element of a 'slice' (2D plane) of output_backprop, we need to
// figure out which input elements contributed to it during the
// FractionalAvgPool forward pass. This can be done based on
// row_pooling_sequence, col_pooling_sequence and overlapping.
// Once we figure out the original contributors, we just need to evenly
// divide the value of this element among these contributors.
//
// Internally, we divide the out_backprop tensor and store it in a temporary
// tensor of double type, and then cast it back to the corresponding type T.
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
ConstEigenMatrixMap;
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
EigenMatrixMap;
typedef Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>>
EigenDoubleMatrixMap;
// Grab the inputs.
const Tensor& orig_input_tensor_shape = context->input(0);
OP_REQUIRES(context, orig_input_tensor_shape.dims() == 1 &&
orig_input_tensor_shape.NumElements() == 4,
errors::InvalidArgument("original input tensor shape must be "
"1-dimensional and 4 elements"));
const Tensor& out_backprop = context->input(1);
const Tensor& row_seq_tensor = context->input(2);
const Tensor& col_seq_tensor = context->input(3);
const int64 out_batch = out_backprop.dim_size(0);
const int64 out_rows = out_backprop.dim_size(1);
const int64 out_cols = out_backprop.dim_size(2);
const int64 out_depth = out_backprop.dim_size(3);
auto row_seq_tensor_flat = row_seq_tensor.flat<int64>();
auto col_seq_tensor_flat = col_seq_tensor.flat<int64>();
auto orig_input_tensor_shape_flat = orig_input_tensor_shape.flat<int64>();
const int64 in_batch = orig_input_tensor_shape_flat(0);
const int64 in_rows = orig_input_tensor_shape_flat(1);
const int64 in_cols = orig_input_tensor_shape_flat(2);
const int64 in_depth = orig_input_tensor_shape_flat(3);
constexpr int tensor_in_and_out_dims = 4;
// Transform orig_input_tensor_shape into TensorShape
TensorShape in_shape;
for (auto i = 0; i < tensor_in_and_out_dims; ++i) {
in_shape.AddDim(orig_input_tensor_shape_flat(i));
}
// Create intermediate in_backprop.
Tensor in_backprop_tensor_temp;
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<double>::v(), in_shape,
&in_backprop_tensor_temp));
in_backprop_tensor_temp.flat<double>().setZero();
// Transform 4D tensor to 2D matrix.
EigenDoubleMatrixMap in_backprop_tensor_temp_mat(
in_backprop_tensor_temp.flat<double>().data(), in_depth,
in_cols * in_rows * in_batch);
ConstEigenMatrixMap out_backprop_mat(out_backprop.flat<T>().data(),
out_depth,
out_cols * out_rows * out_batch);
// Loop through each element of out_backprop and evenly distribute the
// element to the corresponding pooling cell.
const int64 in_max_row_index = in_rows - 1;
const int64 in_max_col_index = in_cols - 1;
for (int64 b = 0; b < out_batch; ++b) {
for (int64 r = 0; r < out_rows; ++r) {
const int64 in_row_start = row_seq_tensor_flat(r);
int64 in_row_end = overlapping_ ? row_seq_tensor_flat(r + 1)
: row_seq_tensor_flat(r + 1) - 1;
in_row_end = std::min(in_row_end, in_max_row_index);
for (int64 c = 0; c < out_cols; ++c) {
const int64 in_col_start = col_seq_tensor_flat(c);
int64 in_col_end = overlapping_ ? col_seq_tensor_flat(c + 1)
: col_seq_tensor_flat(c + 1) - 1;
in_col_end = std::min(in_col_end, in_max_col_index);
const int64 num_elements_in_pooling_cell =
(in_row_end - in_row_start + 1) * (in_col_end - in_col_start + 1);
const int64 out_index = (b * out_rows + r) * out_cols + c;
// Now we can evenly distribute out_backprop(b, h, w, *) to
// in_backprop(b, hs:he, ws:we, *).
for (int64 in_r = in_row_start; in_r <= in_row_end; ++in_r) {
for (int64 in_c = in_col_start; in_c <= in_col_end; ++in_c) {
const int64 in_index = (b * in_rows + in_r) * in_cols + in_c;
// Walk through each channel (depth).
for (int64 d = 0; d < out_depth; ++d) {
const double out_backprop_element = static_cast<double>(
out_backprop_mat.coeffRef(d, out_index));
double& in_backprop_ref =
in_backprop_tensor_temp_mat.coeffRef(d, in_index);
in_backprop_ref +=
out_backprop_element / num_elements_in_pooling_cell;
}
}
}
}
}
}
// Depending on the type, cast double to type T.
Tensor* in_backprop_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, in_shape, &in_backprop_tensor));
auto in_backprop_tensor_flat = in_backprop_tensor->flat<T>();
auto in_backprop_tensor_temp_flat = in_backprop_tensor_temp.flat<double>();
for (int64 i = 0; i < in_backprop_tensor_flat.size(); ++i) {
in_backprop_tensor_flat(i) =
static_cast<T>(in_backprop_tensor_temp_flat(i));
}
}
private:
bool overlapping_;
};
#define REGISTER_FRACTIONALAVGPOOLGRAD(type) \
REGISTER_KERNEL_BUILDER(Name("FractionalAvgPoolGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
FractionalAvgPoolGradOp<type>)
REGISTER_FRACTIONALAVGPOOLGRAD(int32);
REGISTER_FRACTIONALAVGPOOLGRAD(int64);
REGISTER_FRACTIONALAVGPOOLGRAD(float);
REGISTER_FRACTIONALAVGPOOLGRAD(double);
#undef REGISTER_FRACTIONALAVGPOOLGRAD
} // namespace tensorflow
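
The even-redistribution idea described in FractionalAvgPoolGrad above can be sketched in a few lines of NumPy for a single 2D plane. This is an illustrative sketch only, not the kernel itself; the helper name is hypothetical, and `row_seq`/`col_seq` are the cumulative pooling sequences produced by the forward op.

import numpy as np

def frac_avg_pool_grad_2d(input_shape, out_backprop, row_seq, col_seq,
                          overlapping=False):
  """Evenly distributes each out_backprop element over its pooling cell."""
  in_backprop = np.zeros(input_shape, dtype=np.float64)
  for r in range(len(row_seq) - 1):
    r0 = row_seq[r]
    # Overlapping cells share their boundary row/col with the next cell.
    r1 = min(row_seq[r + 1] + 1 if overlapping else row_seq[r + 1],
             input_shape[0])
    for c in range(len(col_seq) - 1):
      c0 = col_seq[c]
      c1 = min(col_seq[c + 1] + 1 if overlapping else col_seq[c + 1],
               input_shape[1])
      cell_size = (r1 - r0) * (c1 - c0)
      in_backprop[r0:r1, c0:c1] += out_backprop[r, c] / cell_size
  return in_backprop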


@@ -0,0 +1,381 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include <algorithm>
#include <cmath>
#include <random>
#include <vector>
#include "tensorflow/core/kernels/fractional_pool_common.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/guarded_philox_random.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename T>
class FractionalMaxPoolOp : public OpKernel {
public:
explicit FractionalMaxPoolOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
OP_REQUIRES(context, pooling_ratio_.size() == 4,
errors::InvalidArgument("pooling_ratio field must "
"specify 4 dimensions"));
OP_REQUIRES(
context, pooling_ratio_[0] == 1 && pooling_ratio_[3] == 1,
errors::Unimplemented("Fractional max pooling is not yet "
"supported on the batch or channel dimension."));
OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
pooling_region_generated_ = false;
// Initialize philox random generator.
OP_REQUIRES_OK(context, generator_.Init(context));
}
void Compute(OpKernelContext* context) override {
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
ConstEigenMatrixMap;
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
EigenMatrixMap;
constexpr int tensor_in_and_out_dims = 4;
const Tensor& tensor_in = context->input(0);
OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
errors::InvalidArgument("tensor_in must be 4-dimensional"));
// Reset the cached sizes so repeated Compute() calls do not accumulate
// stale dimensions.
input_size_.clear();
output_size_.clear();
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
input_size_.push_back(tensor_in.dim_size(i));
}
// Output size.
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
output_size_.push_back(
static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
DCHECK_GT(output_size_[i], 0);
}
// Generate pooling sequence.
std::vector<int64> height_cum_seq;
std::vector<int64> width_cum_seq;
if (deterministic_) {
// Grab the lock before reading pooling_region_generated_ or the cached
// sequences, so concurrent Compute() calls stay consistent.
mutex_lock lock(mu_);
if (pooling_region_generated_) {
height_cum_seq = height_cum_seq_;
width_cum_seq = width_cum_seq_;
} else {
height_cum_seq = GeneratePoolingSequence(
input_size_[1], output_size_[1], &generator_, pseudo_random_);
width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
&generator_, pseudo_random_);
height_cum_seq_ = height_cum_seq;
width_cum_seq_ = width_cum_seq;
pooling_region_generated_ = true;
}
} else {
height_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
&generator_, pseudo_random_);
width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
&generator_, pseudo_random_);
}
// Prepare output.
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
0, TensorShape({output_size_[0], output_size_[1],
output_size_[2], output_size_[3]}),
&output_tensor));
Tensor* output_height_seq_tensor = nullptr;
OP_REQUIRES_OK(
context,
context->allocate_output(
1, TensorShape({static_cast<int64>(height_cum_seq.size())}),
&output_height_seq_tensor));
Tensor* output_width_seq_tensor = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(
2, TensorShape({static_cast<int64>(width_cum_seq.size())}),
&output_width_seq_tensor));
ConstEigenMatrixMap in_mat(
tensor_in.flat<T>().data(), input_size_[3],
input_size_[2] * input_size_[1] * input_size_[0]);
EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
output_size_[2] * output_size_[1] * output_size_[0]);
// Initializes the output tensor with MIN<T>.
output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
auto output_height_seq_flat = output_height_seq_tensor->flat<int64>();
auto output_width_seq_flat = output_width_seq_tensor->flat<int64>();
// Set output tensors.
for (int i = 0; i < height_cum_seq.size(); ++i) {
output_height_seq_flat(i) = height_cum_seq[i];
}
for (int i = 0; i < width_cum_seq.size(); ++i) {
output_width_seq_flat(i) = width_cum_seq[i];
}
// For both input and output,
// 0: batch
// 1: height / row
// 2: width / col
// 3: depth / channel
const int64 height_max = input_size_[1] - 1;
const int64 width_max = input_size_[2] - 1;
for (int64 b = 0; b < input_size_[0]; ++b) {
// height sequence.
for (int64 hs = 0; hs < height_cum_seq.size() - 1; ++hs) {
// height start and end.
const int64 height_start = height_cum_seq[hs];
int64 height_end =
overlapping_ ? height_cum_seq[hs + 1] : height_cum_seq[hs + 1] - 1;
height_end = std::min(height_end, height_max);
// width sequence.
for (int64 ws = 0; ws < width_cum_seq.size() - 1; ++ws) {
const int64 out_offset =
(b * output_size_[1] + hs) * output_size_[2] + ws;
// width start and end.
const int64 width_start = width_cum_seq[ws];
int64 width_end =
overlapping_ ? width_cum_seq[ws + 1] : width_cum_seq[ws + 1] - 1;
width_end = std::min(width_end, width_max);
for (int64 h = height_start; h <= height_end; ++h) {
for (int64 w = width_start; w <= width_end; ++w) {
const int64 in_offset =
(b * input_size_[1] + h) * input_size_[2] + w;
out_mat.col(out_offset) =
out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
}
}
}
}
}
}
private:
bool deterministic_;
// meaningful only when deterministic_ is true.
mutex mu_;
std::vector<int64> height_cum_seq_;
std::vector<int64> width_cum_seq_;
bool pooling_region_generated_;
std::vector<int32> input_size_;
std::vector<int32> output_size_;
std::vector<float> pooling_ratio_;
bool pseudo_random_;
bool overlapping_;
GuardedPhiloxRandom generator_;
};
#define REGISTER_FRACTIONALMAXPOOL(type) \
REGISTER_KERNEL_BUILDER( \
Name("FractionalMaxPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
FractionalMaxPoolOp<type>)
REGISTER_FRACTIONALMAXPOOL(int32);
REGISTER_FRACTIONALMAXPOOL(int64);
REGISTER_FRACTIONALMAXPOOL(float);
REGISTER_FRACTIONALMAXPOOL(double);
#undef REGISTER_FRACTIONALMAXPOOL
static const int kInvalidMaxPoolingIndex = -1;
template <class T>
class FractionalMaxPoolGradOp : public OpKernel {
public:
explicit FractionalMaxPoolGradOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
}
void Compute(OpKernelContext* context) override {
// There are two steps to calculating the gradient of FractionalMaxPool.
// 1) Walk through the process of calculating fractional pooling given the
// pooling region; in the process, keep track of where each max
// element comes from (arg_max).
// 2) Propagate each value of out_backprop to where arg_max indicates. When
// overlapping is enabled, multiple out_backprop[i] values may propagate
// back to the same arg_max index.
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
ConstEigenMatrixMap;
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
EigenMatrixMap;
typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
EigenIndexMatrixMap;
const Tensor& tensor_in = context->input(0);
const Tensor& tensor_out = context->input(1);
const Tensor& out_backprop = context->input(2);
const Tensor& height_seq_tensor = context->input(3);
const Tensor& width_seq_tensor = context->input(4);
// Just to make it similar to FractionalMaxPoolOp.
constexpr int tensor_in_and_out_dims = 4;
std::vector<int64> input_size;
std::vector<int64> output_size;
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
input_size.push_back(tensor_in.dim_size(i));
}
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
output_size.push_back(tensor_out.dim_size(i));
}
// ---------
// Step 1
// ---------
Tensor tensor_out_dup;
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::v(),
tensor_out.shape(), &tensor_out_dup));
Tensor tensor_out_arg_max;
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(),
tensor_out.shape(),
&tensor_out_arg_max));
// Find arg_max for each tensor_out
ConstEigenMatrixMap tensor_in_mat(
tensor_in.flat<T>().data(), input_size[3],
input_size[2] * input_size[1] * input_size[0]);
EigenMatrixMap tensor_out_dup_mat(
tensor_out_dup.flat<T>().data(), output_size[3],
output_size[2] * output_size[1] * output_size[0]);
EigenIndexMatrixMap tensor_out_arg_max_mat(
tensor_out_arg_max.flat<int64>().data(), output_size[3],
output_size[2] * output_size[1] * output_size[0]);
tensor_out_arg_max.flat<int64>().setConstant(kInvalidMaxPoolingIndex);
// Initializes the duplicate output tensor with MIN<T>.
tensor_out_dup.flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
auto height_seq_tensor_flat = height_seq_tensor.flat<int64>();
auto width_seq_tensor_flat = width_seq_tensor.flat<int64>();
// Now walk through the process of fractional max pooling again.
// For both input and output,
// 0: batch
// 1: height / row
// 2: width / col
// 3: depth / channel
const int64 height_max = input_size[1] - 1;
const int64 width_max = input_size[2] - 1;
for (int64 b = 0; b < input_size[0]; ++b) {
// height sequence.
for (int64 hs = 0; hs < height_seq_tensor.dim_size(0) - 1; ++hs) {
// height start and end.
const int64 height_start = height_seq_tensor_flat(hs);
int64 height_end = overlapping_ ? height_seq_tensor_flat(hs + 1)
: height_seq_tensor_flat(hs + 1) - 1;
height_end = std::min(height_end, height_max);
// width sequence.
for (int64 ws = 0; ws < width_seq_tensor.dim_size(0) - 1; ++ws) {
const int64 out_index =
(b * output_size[1] + hs) * output_size[2] + ws;
// width start and end.
const int64 width_start = width_seq_tensor_flat(ws);
int64 width_end = overlapping_ ? width_seq_tensor_flat(ws + 1)
: width_seq_tensor_flat(ws + 1) - 1;
width_end = std::min(width_end, width_max);
for (int64 h = height_start; h <= height_end; ++h) {
for (int64 w = width_start; w <= width_end; ++w) {
const int64 in_index =
(b * input_size[1] + h) * input_size[2] + w;
// Walk through each channel (depth).
for (int64 d = 0; d < input_size[3]; ++d) {
const T& input_ref = tensor_in_mat.coeffRef(d, in_index);
T& output_ref = tensor_out_dup_mat.coeffRef(d, out_index);
int64& out_arg_max_ref =
tensor_out_arg_max_mat.coeffRef(d, out_index);
if (output_ref < input_ref ||
out_arg_max_ref == kInvalidMaxPoolingIndex) {
output_ref = input_ref;
const int64 input_offset = in_index * input_size[3] + d;
out_arg_max_ref = input_offset;
}
}
}
}
}
}
}
// Check tensor_out_dup is the same as tensor_out.
ConstEigenMatrixMap tensor_out_mat(
tensor_out.flat<T>().data(), output_size[3],
output_size[2] * output_size[1] * output_size[0]);
const int64 num_reshaped_cols =
output_size[2] * output_size[1] * output_size[0];
for (int64 i = 0; i < num_reshaped_cols; ++i) {
for (int64 j = 0; j < output_size[3]; ++j) {
DCHECK_EQ(tensor_out_dup_mat(j, i), tensor_out_mat(j, i));
}
}
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, tensor_in.shape(), &output));
output->flat<T>().setZero();
auto out_backprop_flat = out_backprop.flat<T>();
auto input_backprop_flat = output->flat<T>();
auto out_arg_max_flat = tensor_out_arg_max.flat<int64>();
int num_total_outputs = out_backprop_flat.size();
int num_total_inputs = input_backprop_flat.size();
for (int index = 0; index < num_total_outputs; ++index) {
int input_backprop_index = out_arg_max_flat(index);
// According to maxpooling_op.cc, the performance impact below is small.
CHECK(input_backprop_index >= 0 &&
input_backprop_index < num_total_inputs)
<< "Invalid input backprop index: " << input_backprop_index << ", "
<< num_total_inputs;
input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
}
}
private:
bool overlapping_;
};
#define REGISTER_FRACTIONALMAXPOOLGRAD(type) \
REGISTER_KERNEL_BUILDER(Name("FractionalMaxPoolGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
FractionalMaxPoolGradOp<type>)
REGISTER_FRACTIONALMAXPOOLGRAD(int32);
REGISTER_FRACTIONALMAXPOOLGRAD(int64);
REGISTER_FRACTIONALMAXPOOLGRAD(float);
REGISTER_FRACTIONALMAXPOOLGRAD(double);
#undef REGISTER_FRACTIONALMAXPOOLGRAD
} // namespace tensorflow
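
The two-step procedure above can likewise be sketched in NumPy for a single 2D plane. Again an illustrative sketch with a hypothetical helper name; ties are broken by the first maximum, matching the strict `<` comparison in the kernel.

import numpy as np

def frac_max_pool_grad_2d(inp, out_backprop, row_seq, col_seq,
                          overlapping=False):
  grad = np.zeros(inp.shape, dtype=np.float64)
  for r in range(len(row_seq) - 1):
    r0 = row_seq[r]
    r1 = min(row_seq[r + 1] + 1 if overlapping else row_seq[r + 1],
             inp.shape[0])
    for c in range(len(col_seq) - 1):
      c0 = col_seq[c]
      c1 = min(col_seq[c + 1] + 1 if overlapping else col_seq[c + 1],
               inp.shape[1])
      # Step 1: recompute arg_max within the pooling cell (first max wins).
      cell = inp[r0:r1, c0:c1]
      ar, ac = np.unravel_index(np.argmax(cell), cell.shape)
      # Step 2: route the gradient back to that input position; with
      # overlapping, several cells may accumulate onto the same element.
      grad[r0 + ar, c0 + ac] += out_backprop[r, c]
  return grad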


@@ -0,0 +1,134 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/fractional_pool_common.h"
#include "tensorflow/core/lib/random/simple_philox.h"
namespace tensorflow {
static std::vector<int64> GeneratePoolingSequencePseudoRandom(
int input_length, int output_length, GuardedPhiloxRandom* generator) {
std::vector<int64> cum_seq(output_length + 1, 0);
std::vector<int64> diff(output_length, 0);
double alpha = static_cast<double>(input_length) / output_length;
int k = input_length / output_length;
// In the paper [1], the author proposes the following procedure to generate
// a pseudo-random pooling region:
// 1) Set a_0 = 1, a_Nout = Nin;
// 2) a_i = ceil(alpha*(u+i))
// in which, i = 1, 2, ... , Nout-1
// u is a random number in (0,1) for all i
// alpha = Nin/Nout in (1,2)
// The sequence {a_i} should satisfy a_i-a_{i-1} = 1 or 2
// Note: for step 1), it makes more sense to make a_Nout = Nin+1, that way,
// a_i-a_{i-1} = 1 or 2 is also true for i = Nout.
//
// However, there are choices of alpha and u that will make
// a_i - a_{i-1} > 2. This happens at the left boundary. For example, with
// alpha = 1.732, u = 0.8, then a_1 = 4, a_1-a_0 = 3.
// This is why u_max1 is needed, i.e. u is a random number in (0,u_max1)
// instead of (0,1).
// Define k = ceil(alpha)-1, then we require:
// a_1 = alpha*(u+1) <= a_0+(k+1)
// ===> This gives u_max1 = (k+2)/alpha - 1.
//
// In addition, when extending the pooling sequence generation process for
// alpha beyond (1,2), e.g. (k,k+1); a check on the right boundary is also
// needed to make sure the last gap a_Nout-a_{Nout-1} >= k. Solving it gives
// u_max2 = (Nin+1-k)/alpha - (Nout-1)
// Here is an example where u > u_max2, alpha = 2.3, u = 0.7, u_max2 = 0.565,
// Nin = 23, Nout = 10; the sequence
// from a_0 to a_10 is:
// [1, 4, 7, 9, 11, 14, 16, 18, 21, 23, 24]
// The last gap is only 1.
//
// [1]: https://arxiv.org/abs/1412.6071
double u_max1 = (k + 2) / alpha - 1;
double u_max2 = (input_length + 1 - k) / alpha - (output_length - 1);
double max_u = std::min(u_max1, u_max2);
// Generate the random number using the thread-safe Philox generator.
auto local_gen = generator->ReserveSamples32(2);
random::SimplePhilox random(&local_gen);
const double u = random.RandDouble() * max_u;
cum_seq[0] = 1;
cum_seq[output_length] = input_length + 1;
for (int i = 1; i < output_length; ++i) {
cum_seq[i] = static_cast<int>(ceil(alpha * (i + u)));
}
for (int i = 0; i < output_length; ++i) {
diff[i] = cum_seq[i + 1] - cum_seq[i];
}
return diff;
}
static std::vector<int64> GeneratePoolingSequenceRandom(
int input_length, int output_length, GuardedPhiloxRandom* generator) {
int k = input_length / output_length;
int num_random_spot = input_length % output_length;
std::vector<int64> diff(output_length, k);
for (int i = 0; i < num_random_spot; ++i) {
diff[i] += 1;
}
// Randomly shuffle this vector.
auto local_gen = generator->ReserveSamples32(diff.size());
random::SingleSampleAdapter<random::PhiloxRandom> single(&local_gen);
const auto uniform = [&single](uint32 n) { return single() % n; };
RandomShuffle(diff.begin(), diff.end(), uniform);
return diff;
}
std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
GuardedPhiloxRandom* generator,
bool pseudo_random) {
std::vector<int64> diff;
// This is a case that regular pooling can handle; just return diff with
// each element equal to input_length/output_length.
if (input_length % output_length == 0) {
diff = std::vector<int64>(output_length, input_length / output_length);
} else if (pseudo_random) {
diff = GeneratePoolingSequencePseudoRandom(input_length, output_length,
generator);
} else {
diff =
GeneratePoolingSequenceRandom(input_length, output_length, generator);
}
// Sanity check.
int k = input_length / output_length;
for (int i = 0; i < output_length; ++i) {
// k <= diff[i] <= k+1.
DCHECK_GE(diff[i], k);
DCHECK_LE(diff[i], k + 1);
}
// Return cumulative sequence.
std::vector<int64> cum_seq(output_length + 1, 0);
for (int i = 1; i < cum_seq.size(); ++i) {
cum_seq[i] = cum_seq[i - 1] + diff[i - 1];
}
return cum_seq;
}
} // namespace tensorflow
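
For reference, the pseudo-random procedure above translates directly into a few lines of Python. This sketch uses a hypothetical helper name and the standard-library RNG instead of Philox; it returns the diff sequence, including the u_max1/u_max2 boundary clamping derived in the comment.

import math
import random

def pooling_sequence_pseudo_random(input_length, output_length, rng=random):
  alpha = float(input_length) / output_length
  k = input_length // output_length
  # Clamp u so both the left and right boundary gaps stay within [k, k+1].
  u_max1 = (k + 2) / alpha - 1
  u_max2 = (input_length + 1 - k) / alpha - (output_length - 1)
  u = rng.random() * min(u_max1, u_max2)
  cum = ([1] + [int(math.ceil(alpha * (i + u)))
                for i in range(1, output_length)] + [input_length + 1])
  return [cum[i + 1] - cum[i] for i in range(output_length)]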


@@ -0,0 +1,78 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
#define TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
#include <algorithm>
#include <vector>
#include "tensorflow/core/util/guarded_philox_random.h"
namespace tensorflow {
// Shuffle a container randomly, copied from random_shuffle_op.cc
template <class Iter, class Random>
static inline void RandomShuffle(Iter first, Iter last, const Random& uniform) {
if (first == last) {
return;
}
const auto stop = last - 1;
for (auto i = first; i != stop; ++i) {
using std::iter_swap;
iter_swap(i, i + uniform(last - i));
}
}
// Generate pooling sequence for fractional pooling along one dimension.
//
// Regular max/avg pooling can be viewed as a special case: given, e.g.,
// * input_length: 10
// * output_length: 5
// it will generate the pooling sequence as
// diff sequence: [2, 2, 2, 2, 2]
// or as
// cumulative sequence: [0, 2, 4, 6, 8, 10]
//
// In the case of fractional pooling, input_length is not an integer multiple
// of output_length, so randomness plays a role when generating the pooling
// sequence. There are two types of randomness (random vs pseudo-random)
// defined in the paper:
// http://arxiv.org/abs/1412.6071
// You can check the paper for the difference between these two types.
//
// In summary, the generated diff sequence satisfies the following properties
// for both types of randomness:
// * length(generated_diff_pooling_sequence) = output_length
// * sum(generated_diff_pooling_sequence) = input_length
// * Let's define floor(input_length / output_length) = K; then
// K <= generated_diff_pooling_sequence[i] <= K+1
// For example, when input_length = 10 and output_length = 6, the following
// are valid pooling sequences:
// * [1, 2, 2, 1, 2, 2]
// * [1, 1, 2, 2, 2, 2]
// [1, 3, 2, 2, 2, 2] is not valid.
//
// Args:
// input_length: See above explanation
// output_length: See above explanation
// generator: Parallel version of random number generator
// pseudo_random: Whether or not to use pseudo-random sequence generation
// Returns:
// pooling_sequence: This is the cumulative pooling sequence.
std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
GuardedPhiloxRandom* generator,
bool pseudo_random);
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
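
The three diff-sequence properties listed above can be expressed as a small check. This is illustrative Python only (hypothetical helper name), useful when experimenting with a generator such as the sketch after fractional_pool_common.cc above.

def check_diff_sequence(diff, input_length, output_length):
  # length, sum, and per-element K/K+1 bounds, as documented above.
  k = input_length // output_length
  assert len(diff) == output_length
  assert sum(diff) == input_length
  assert all(k <= d <= k + 1 for d in diff)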


@@ -1515,4 +1515,205 @@ values: The `k` largest elements along each last dimensional slice.
indices: The indices of `values` within the last dimension of `input`.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("FractionalMaxPool")
.Input("value: T")
.Output("output: T")
.Output("row_pooling_sequence: int64")
.Output("col_pooling_sequence: int64")
.Attr("pooling_ratio: list(float) >=4")
.Attr("pseudo_random: bool = false")
.Attr("overlapping: bool = false")
.Attr("deterministic: bool = false")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("T: {float, double, int32, int64}")
.Doc(R"doc(
Performs fractional max pooling on the input.
Fractional max pooling is slightly different than regular max pooling. In
regular max pooling, you downsize an input set by taking the maximum value of
smaller N x N subsections of the set (often 2x2), and try to reduce the set by
a factor of N, where N is an integer. Fractional max pooling, as you might
expect from the word "fractional", means that the overall reduction ratio N
does not have to be an integer.
The sizes of the pooling regions are generated randomly but are fairly uniform.
For example, let's look at the height dimension, and the constraints on the
list of rows that will be pool boundaries.
First we define the following:
1. input_row_length : the number of rows from the input set
2. output_row_length : which will be smaller than the input
3. alpha = input_row_length / output_row_length : our reduction ratio
4. K = floor(alpha)
5. row_pooling_sequence : this is the result list of pool boundary rows
Then, row_pooling_sequence should satisfy:
1. a[0] = 0 : the first value of the sequence is 0
2. a[end] = input_row_length : the last value of the sequence is the size
3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
4. length(row_pooling_sequence) = output_row_length+1
For more details on fractional max pooling, see this paper:
[Benjamin Graham, Fractional Max-Pooling]
(http://arxiv.org/abs/1412.6071)
value: 4-D with shape `[batch, height, width, channels]`.
pooling_ratio: Pooling ratio for each dimension of `value`. Currently only the
row and col dimensions support pooling, and each ratio should be >= 1.0. For
example, a valid pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first
and last elements must be 1.0 because we don't allow pooling on the batch and
channels dimensions; 1.44 and 1.73 are the pooling ratios on the height and
width dimensions respectively.
pseudo_random: When set to True, generates the pooling sequence in a
pseudorandom fashion, otherwise, in a random fashion. Check the paper [Benjamin
Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for the
difference between pseudorandom and random.
overlapping: When set to True, the values at the boundary of adjacent pooling
cells are used by both cells. For example:
`index 0 1 2 3 4`
`value 20 5 16 3 7`
If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used twice.
The result would be [20, 16] for fractional max pooling.
deterministic: When set to True, a fixed pooling region will be used when
iterating over a FractionalMaxPool node in the computation graph. Mainly used
in unit tests to make FractionalMaxPool deterministic.
seed: If either seed or seed2 is set to be non-zero, the random number
generator is seeded by the given seed. Otherwise, it is seeded by a
random seed.
seed2: A second seed to avoid seed collision.
output: output tensor after fractional max pooling.
row_pooling_sequence: row pooling sequence, needed to calculate gradient.
col_pooling_sequence: column pooling sequence, needed to calculate gradient.
)doc");
REGISTER_OP("FractionalMaxPoolGrad")
.Input("orig_input: T")
.Input("orig_output: T")
.Input("out_backprop: T")
.Input("row_pooling_sequence: int64")
.Input("col_pooling_sequence: int64")
.Output("output: T")
.Attr("overlapping: bool = false")
.Attr("T: {float, double, int32, int64}")
.Doc(R"doc(
Computes gradient of the FractionalMaxPool function.
orig_input: Original input for `fractional_max_pool`
orig_output: Original output for `fractional_max_pool`
out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients
w.r.t. the output of `fractional_max_pool`.
row_pooling_sequence: row pooling sequence, which together with
col_pooling_sequence forms the pooling regions.
col_pooling_sequence: column pooling sequence, which together with
row_pooling_sequence forms the pooling regions.
overlapping: When set to True, the values at the boundary of adjacent pooling
cells are used by both cells. For example:
`index 0 1 2 3 4`
`value 20 5 16 3 7`
If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used twice.
The result would be [20, 16] for fractional max pooling.
output: 4-D. Gradients w.r.t. the input of `fractional_max_pool`.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("FractionalAvgPool")
.Input("value: T")
.Output("output: T")
.Output("row_pooling_sequence: int64")
.Output("col_pooling_sequence: int64")
.Attr("pooling_ratio: list(float) >=4")
.Attr("pseudo_random: bool = false")
.Attr("overlapping: bool = false")
.Attr("deterministic: bool = false")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("T: {float, double, int32, int64}")
.Doc(R"doc(
Performs fractional average pooling on the input.
Fractional average pooling is similar to fractional max pooling in the pooling
region generation step. The only difference is that after pooling regions are
generated, a mean operation is performed instead of a max operation in each
pooling region.
value: 4-D with shape `[batch, height, width, channels]`.
pooling_ratio: Pooling ratio for each dimension of `value`. Currently only the
row and col dimensions support pooling, and each ratio should be >= 1.0. For
example, a valid pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first
and last elements must be 1.0 because we don't allow pooling on the batch and
channels dimensions; 1.44 and 1.73 are the pooling ratios on the height and
width dimensions respectively.
pseudo_random: When set to True, generates the pooling sequence in a
pseudorandom fashion, otherwise, in a random fashion. Check the paper [Benjamin
Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for the
difference between pseudorandom and random.
overlapping: When set to True, the values at the boundary of adjacent pooling
cells are used by both cells. For example:
`index 0 1 2 3 4`
`value 20 5 16 3 7`
If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used twice.
The result would be [41/3, 26/3] for fractional avg pooling.
deterministic: When set to True, a fixed pooling region will be used when
iterating over a FractionalAvgPool node in the computation graph. Mainly used
in unit tests to make FractionalAvgPool deterministic.
seed: If either seed or seed2 is set to be non-zero, the random number
generator is seeded by the given seed. Otherwise, it is seeded by a
random seed.
seed2: A second seed to avoid seed collision.
output: output tensor after fractional avg pooling.
row_pooling_sequence: row pooling sequence, needed to calculate gradient.
col_pooling_sequence: column pooling sequence, needed to calculate gradient.
)doc");
REGISTER_OP("FractionalAvgPoolGrad")
.Input("orig_input_tensor_shape: int64")
.Input("out_backprop: T")
.Input("row_pooling_sequence: int64")
.Input("col_pooling_sequence: int64")
.Output("output: T")
.Attr("overlapping: bool = false")
.Attr("T: {float, double, int32, int64}")
.Doc(R"doc(
Computes gradient of the FractionalAvgPool function.
Unlike FractionalMaxPoolGrad, we don't need to find arg_max for
FractionalAvgPoolGrad; we just need to evenly back-propagate each element of
out_backprop to those indices that form the same pooling cell. Therefore, we
just need to know the shape of the original input tensor, instead of the
whole tensor.
orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool`
out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients
w.r.t. the output of `fractional_avg_pool`.
row_pooling_sequence: row pooling sequence, which together with
col_pooling_sequence forms the pooling regions.
col_pooling_sequence: column pooling sequence, which together with
row_pooling_sequence forms the pooling regions.
overlapping: When set to True, the values at the boundary of adjacent pooling
cells are used by both cells. For example:
`index 0 1 2 3 4`
`value 20 5 16 3 7`
If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used twice.
The result would be [41/3, 26/3] for fractional avg pooling.
output: 4-D. Gradients w.r.t. the input of `fractional_avg_pool`.
)doc");
} // namespace tensorflow


@@ -0,0 +1,82 @@
### `tf.nn.fractional_max_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_max_pool}
Performs fractional max pooling on the input.
Fractional max pooling is slightly different than regular max pooling. In
regular max pooling, you downsize an input set by taking the maximum value of
smaller N x N subsections of the set (often 2x2), and try to reduce the set by
a factor of N, where N is an integer. Fractional max pooling, as you might
expect from the word "fractional", means that the overall reduction ratio N
does not have to be an integer.
The sizes of the pooling regions are generated randomly but are fairly uniform.
For example, let's look at the height dimension, and the constraints on the
list of rows that will be pool boundaries.
First we define the following:
1. input_row_length : the number of rows from the input set
2. output_row_length : which will be smaller than the input
3. alpha = input_row_length / output_row_length : our reduction ratio
4. K = floor(alpha)
5. row_pooling_sequence : this is the result list of pool boundary rows
Then, row_pooling_sequence should satisfy:
1. a[0] = 0 : the first value of the sequence is 0
2. a[end] = input_row_length : the last value of the sequence is the size
3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
4. length(row_pooling_sequence) = output_row_length+1
For more details on fractional max pooling, see this paper:
[Benjamin Graham, Fractional Max-Pooling]
(http://arxiv.org/abs/1412.6071)
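
For example, the four constraints above can be expressed as a small check (illustrative Python, not part of the API):

def is_valid_row_pooling_sequence(a, input_row_length):
  output_row_length = len(a) - 1               # constraint 4
  k = input_row_length // output_row_length    # K = floor(alpha)
  return (a[0] == 0 and                        # constraint 1
          a[-1] == input_row_length and        # constraint 2
          all(k <= a[i + 1] - a[i] <= k + 1    # constraint 3
              for i in range(output_row_length)))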
##### Args:
* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
Pooling ratio for each dimension of `value`. Currently only the row and col
dimensions support pooling, and each ratio should be >= 1.0. For example, a
valid pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last
elements must be 1.0 because we don't allow pooling on the batch and
channels dimensions; 1.44 and 1.73 are the pooling ratios on the height and
width dimensions respectively.
* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
When set to True, generates the pooling sequence in a pseudorandom fashion,
otherwise, in a random fashion. Check the paper [Benjamin Graham, Fractional
Max-Pooling] (http://arxiv.org/abs/1412.6071) for the difference between
pseudorandom and random.
* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
When set to True, the values at the boundary of adjacent pooling cells are
used by both cells. For example:
`index 0 1 2 3 4`
`value 20 5 16 3 7`
If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used twice.
The result would be [20, 16] for fractional max pooling.
* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
When set to True, a fixed pooling region will be used when iterating over a
FractionalMaxPool node in the computation graph. Mainly used in unit tests
to make FractionalMaxPool deterministic.
* <b>`seed`</b>: An optional `int`. Defaults to `0`.
If either seed or seed2 is set to be non-zero, the random number generator
is seeded by the given seed. Otherwise, it is seeded by a random seed.
* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
A second seed to avoid seed collision.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
* <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional max pooling.
* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
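
A minimal usage sketch under the API documented above; the input shape is illustrative, and output sizes follow floor(input_size / pooling_ratio):

import tensorflow as tf

value = tf.random_uniform([1, 10, 10, 3])
output, row_seq, col_seq = tf.nn.fractional_max_pool(
    value, pooling_ratio=[1.0, 1.44, 1.73, 1.0], pseudo_random=True)
# output shape: [1, floor(10/1.44), floor(10/1.73), 3] = [1, 6, 5, 3]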


@@ -0,0 +1,57 @@
### `tf.nn.fractional_avg_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_avg_pool}
Performs fractional average pooling on the input.
Fractional average pooling is similar to fractional max pooling in the pooling
region generation step. The only difference is that after pooling regions are
generated, a mean operation is performed instead of a max operation in each
pooling region.
##### Args:
* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
Pooling ratio for each dimension of `value`. Currently only the row and col
dimensions support pooling, and each ratio should be >= 1.0. For example, a
valid pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last
elements must be 1.0 because we don't allow pooling on the batch and
channels dimensions; 1.44 and 1.73 are the pooling ratios on the height and
width dimensions respectively.
* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
When set to True, generates the pooling sequence in a pseudorandom fashion,
otherwise, in a random fashion. Check the paper [Benjamin Graham, Fractional
Max-Pooling] (http://arxiv.org/abs/1412.6071) for the difference between
pseudorandom and random.
* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
When set to True, the values at the boundary of adjacent pooling cells are
used by both cells. For example:
`index 0 1 2 3 4`
`value 20 5 16 3 7`
If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used twice.
The result would be [41/3, 26/3] for fractional avg pooling (see the check
after this section).
* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
When set to True, a fixed pooling region will be used when iterating over a
FractionalAvgPool node in the computation graph. Mainly used in unit tests
to make FractionalAvgPool deterministic.
* <b>`seed`</b>: An optional `int`. Defaults to `0`.
If either seed or seed2 is set to be non-zero, the random number generator
is seeded by the given seed. Otherwise, it is seeded by a random seed.
* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
A second seed to avoid seed collision.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
* <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional avg pooling.
* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
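
The [41/3, 26/3] overlapping example above can be verified with a couple of lines of plain Python (illustrative only):

values = [20, 5, 16, 3, 7]
seq = [0, 2, 4]  # cumulative pooling sequence
# With overlapping, cell i covers indices seq[i]..seq[i+1] inclusive.
cells = [values[seq[i]:seq[i + 1] + 1] for i in range(len(seq) - 1)]
means = [sum(c) / float(len(c)) for c in cells]  # [41/3, 26/3]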


@@ -473,6 +473,8 @@
* [`embedding_lookup_sparse`](../../api_docs/python/nn.md#embedding_lookup_sparse)
* [`erosion2d`](../../api_docs/python/nn.md#erosion2d)
* [`fixed_unigram_candidate_sampler`](../../api_docs/python/nn.md#fixed_unigram_candidate_sampler)
* [`fractional_avg_pool`](../../api_docs/python/nn.md#fractional_avg_pool)
* [`fractional_max_pool`](../../api_docs/python/nn.md#fractional_max_pool)
* [`in_top_k`](../../api_docs/python/nn.md#in_top_k)
* [`l2_loss`](../../api_docs/python/nn.md#l2_loss)
* [`l2_normalize`](../../api_docs/python/nn.md#l2_normalize)


@@ -828,6 +828,151 @@ Performs 3D max pooling on the input.
A `Tensor`. Has the same type as `input`. The max pooled output tensor.
- - -
### `tf.nn.fractional_avg_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_avg_pool}
Performs fractional average pooling on the input.
Fractional average pooling is similar to fractional max pooling in the pooling
region generation step. The only difference is that after pooling regions are
generated, a mean operation is performed instead of a max operation in each
pooling region.
##### Args:
* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
Pooling ratio for each dimension of `value`. Currently only the row and col
dimensions support pooling, and each ratio should be >= 1.0. For example, a
valid pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last
elements must be 1.0 because we don't allow pooling on the batch and
channels dimensions; 1.44 and 1.73 are the pooling ratios on the height and
width dimensions respectively.
* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
When set to True, generates the pooling sequence in a pseudorandom fashion,
otherwise, in a random fashion. Check the paper [Benjamin Graham, Fractional
Max-Pooling] (http://arxiv.org/abs/1412.6071) for the difference between
pseudorandom and random.
* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
When set to True, the values at the boundary of adjacent pooling cells are
used by both cells. For example:
`index 0 1 2 3 4`
`value 20 5 16 3 7`
If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used twice.
The result would be [41/3, 26/3] for fractional avg pooling.
* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
When set to True, a fixed pooling region will be used when iterating over a
FractionalAvgPool node in the computation graph. Mainly used in unit tests
to make FractionalAvgPool deterministic.
* <b>`seed`</b>: An optional `int`. Defaults to `0`.
If either seed or seed2 is set to be non-zero, the random number generator
is seeded by the given seed. Otherwise, it is seeded by a random seed.
* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
A second seed to avoid seed collision.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
* <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional avg pooling.
* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
- - -
### `tf.nn.fractional_max_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_max_pool}
Performs fractional max pooling on the input.
Fractional max pooling is slightly different than regular max pooling. In
regular max pooling, you downsize an input set by taking the maximum value of
smaller N x N subsections of the set (often 2x2), and try to reduce the set by
a factor of N, where N is an integer. Fractional max pooling, as you might
expect from the word "fractional", means that the overall reduction ratio N
does not have to be an integer.
The sizes of the pooling regions are generated randomly but are fairly uniform.
For example, let's look at the height dimension, and the constraints on the
list of rows that will be pool boundaries.
First we define the following:
1. input_row_length : the number of rows from the input set
2. output_row_length : which will be smaller than the input
3. alpha = input_row_length / output_row_length : our reduction ratio
4. K = floor(alpha)
5. row_pooling_sequence : this is the result list of pool boundary rows
Then, row_pooling_sequence should satisfy:
1. a[0] = 0 : the first value of the sequence is 0
2. a[end] = input_row_length : the last value of the sequence is the size
3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
4. length(row_pooling_sequence) = output_row_length+1
For more details on fractional max pooling, see this paper:
[Benjamin Graham, Fractional Max-Pooling]
(http://arxiv.org/abs/1412.6071)
##### Args:
* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
Pooling ratio for each dimension of `value`. Currently only the row and col
dimensions support pooling, and each ratio should be >= 1.0. For example, a
valid pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last
elements must be 1.0 because we don't allow pooling on the batch and
channels dimensions; 1.44 and 1.73 are the pooling ratios on the height and
width dimensions respectively.
* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
When set to True, generates the pooling sequence in a pseudorandom fashion,
otherwise, in a random fashion. Check the paper [Benjamin Graham, Fractional
Max-Pooling] (http://arxiv.org/abs/1412.6071) for the difference between
pseudorandom and random.
* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
When set to True, the values at the boundary of adjacent pooling cells are
used by both cells. For example:
`index 0 1 2 3 4`
`value 20 5 16 3 7`
If the pooling sequence is [0, 2, 4], then 16, at index 2, will be used twice.
The result would be [20, 16] for fractional max pooling.
* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
When set to True, a fixed pooling region will be used when iterating over a
FractionalMaxPool node in the computation graph. Mainly used in unit tests
to make FractionalMaxPool deterministic (see the sketch after this section).
* <b>`seed`</b>: An optional `int`. Defaults to `0`.
If either seed or seed2 is set to be non-zero, the random number generator
is seeded by the given seed. Otherwise, it is seeded by a random seed.
* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
A second seed to avoid seed collision.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
* <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional max pooling.
* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
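
As a usage note for the deterministic and seed arguments above, a sketch with illustrative shapes:

import tensorflow as tf

value = tf.ones([2, 8, 8, 1])
out, rows, cols = tf.nn.fractional_max_pool(
    value, pooling_ratio=[1.0, 1.41, 1.41, 1.0],
    deterministic=True, seed=13, seed2=37)
# With deterministic=True and fixed seeds, every evaluation of the node
# reuses the same pooling regions.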
## Morphological filtering


@@ -36,6 +36,8 @@ py_tests(
"determinant_op_test.py",
"edit_distance_op_test.py",
"fifo_queue_test.py",
"fractional_avg_pool_op_test.py",
"fractional_max_pool_op_test.py",
"identity_op_py_test.py",
"in_topk_op_test.py",
"io_ops_test.py",


@@ -0,0 +1,521 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fractional average pool operation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops
class FractionalAvgTest(tf.test.TestCase):
# Random number generator with seed.
_PRNG = np.random.RandomState(341261000)
_SEED = 341261001
_SEED2 = 341261002
def _AvgPoolAlongRows(self, input_matrix, row_seq, overlapping):
"""Perform average pool along row of a 2-D matrix based on row_seq.
Args:
input_matrix: A 2-D matrix.
row_seq: Cumulative pooling sequence along row.
overlapping: Whether or not to use overlapping when pooling.
Returns:
A 2-D matrix, with
* num_rows = len(row_seq)-1
* num_cols = input_matrix.num_cols.
"""
output_image = np.zeros(input_matrix.shape[1])
row_max = row_seq[-1]
for i in range(row_seq.shape[0] - 1):
row_start = row_seq[i]
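# With overlapping, a pooling cell also consumes its right-boundary row,
# which is shared with the next cell; the min() below clamps the final
# cell to the input.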
row_end = row_seq[i + 1] + 1 if overlapping else row_seq[i + 1]
row_end = min(row_end, row_max)
output_image = np.vstack((output_image,
np.mean(input_matrix[row_start:row_end, :],
axis=0))) # axis 0 is along row
# remove the sentinel row
return output_image[1:, :]
def _AvgPoolAlongCols(self, input_matrix, col_seq, overlapping):
"""Perform average pool along column of a 2-D matrix based on col_seq.
Args:
input_matrix: A 2-D matrix.
col_seq: Cumulative pooling sequence along column.
overlapping: Whether or not to use overlapping when pooling.
Returns:
A 2-D matrix, with
* num_rows = input_matrix.num_rows
* num_cols = len(col_seq)-1.
"""
input_matrix = input_matrix.transpose()
output_matrix = self._AvgPoolAlongRows(input_matrix, col_seq, overlapping)
return output_matrix.transpose()
def _GetExpectedFractionalAvgPoolResult(self, input_tensor, row_seq, col_seq,
overlapping):
"""Get expected fractional average pooling result.
row_seq and col_seq together define the fractional pooling region.
Args:
input_tensor: Original input tensor, assuming it is a 4-D tensor, with
dimension as [batch, height/row, width/column, channels/depth].
row_seq: Cumulative pooling sequence along row.
col_seq: Cumulative pooling sequence along column.
overlapping: Use overlapping when doing pooling.
Returns:
A 4-D tensor that is the result of average pooling on input_tensor based
on pooling region defined by row_seq and col_seq, conditioned on whether
or not overlapping is used.
"""
input_shape = input_tensor.shape
output_shape = (input_shape[0], len(row_seq) - 1, len(col_seq) - 1,
input_shape[3])
output_tensor = np.zeros(shape=output_shape, dtype=input_tensor.dtype)
for batch in range(input_shape[0]):
for channel in range(input_shape[3]):
two_dim_slice = input_tensor[batch, :, :, channel]
tmp = self._AvgPoolAlongRows(two_dim_slice, row_seq, overlapping)
output_tensor[batch, :, :, channel] = self._AvgPoolAlongCols(
tmp, col_seq, overlapping)
return output_tensor
def _ValidateFractionalAvgPoolResult(self, input_tensor, pooling_ratio,
pseudo_random, overlapping):
"""Validate FractionalAvgPool's result against expected.
Expected result is computed given input_tensor, and pooling region defined
by row_seq and col_seq.
Args:
input_tensor: A tensor or numpy ndarray.
pooling_ratio: A list or tuple of length 4; the first and last elements must be 1.
pseudo_random: Use pseudo random method to generate pooling sequence.
overlapping: Use overlapping when pooling.
Returns:
None
"""
with self.test_session() as sess:
p, r, c = tf.nn.fractional_avg_pool(input_tensor,
pooling_ratio,
pseudo_random,
overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
actual, row_seq, col_seq = sess.run([p, r, c])
expected = self._GetExpectedFractionalAvgPoolResult(input_tensor, row_seq,
col_seq, overlapping)
self.assertShapeEqual(expected, p)
self.assertAllClose(expected, actual)
def _testVisually(self):
"""Manual test by printing out intermediate result of a small random tensor.
Since _GetExpectedFractionalAvgPoolResult is 'automated', it feels safer to
have a test case that you can see what's happening.
This test will generate a small, random, int 2D matrix, and feed it to
FractionalAvgPool and _GetExpectedFractionalAvgPoolResult.
"""
num_rows = 6
num_cols = 6
tensor_shape = (1, num_rows, num_cols, 1)
pseudo_random = False
for overlapping in True, False:
print("-" * 70)
print("Testing FractionalAvgPool with overlapping = {}".format(
overlapping))
rand_mat = self._PRNG.randint(10, size=tensor_shape)
pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
with self.test_session() as sess:
p, r, c = tf.nn.fractional_avg_pool(
rand_mat.astype(np.float32),
pooling_ratio,
pseudo_random,
overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
tensor_output, row_seq, col_seq = sess.run([p, r, c])
expected_result = self._GetExpectedFractionalAvgPoolResult(
rand_mat.astype(np.float32), row_seq, col_seq, overlapping)
print("row sequence:")
print(row_seq)
print("column sequence:")
print(col_seq)
print("Input:")
# Print input with pooling region marked.
for i in range(num_rows):
row_to_print = []
for j in range(num_cols):
if j in col_seq:
row_to_print.append("|")
row_to_print.append(str(rand_mat[0, i, j, 0]))
row_to_print.append("|")
if i in row_seq:
print("-" * 2 * len(row_to_print))
print(" ".join(row_to_print))
print("-" * 2 * len(row_to_print))
print("Output from FractionalAvgPool:")
print(tensor_output[0, :, :, 0])
print("Expected result:")
print(expected_result[0, :, :, 0])
def testAllInputOptions(self):
"""Try all possible input options for fractional_avg_pool.
"""
num_batches = 5
num_channels = 3
num_rows = 20
num_cols = 30
for pseudo_random in True, False:
for overlapping in True, False:
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
# random tensor with value in [-500.0, 500.0)
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
self._ValidateFractionalAvgPoolResult(
rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
overlapping)
def testIntegerTensorInput(self):
"""Test FractionalAvgPool works fine when input tensor is integer type.
I would have used the _ValidateFractionalAvgPoolResult function to automate
this process; however, there is a rounding issue: numpy.mean casts integer
input to numpy.float64 for intermediate use, while for fractional_avg_pool
the mean operation on integers is truncated integer division. So, for this
test case, I will hard-code a simple matrix.
"""
pseudo_random = True
overlapping = True
tensor_shape = (1, 6, 6, 1)
# pyformat: disable
mat = np.array([
[2, 6, 4, 1, 3, 6],
[8, 9, 1, 6, 6, 8],
[3, 9, 8, 2, 5, 6],
[2, 7, 9, 5, 4, 5],
[8, 5, 0, 5, 7, 4],
[4, 4, 5, 9, 7, 2]
])
# pyformat: enable
with self.test_session() as sess:
# Since deterministic = True, seed and seed2 are fixed. Therefore r and c
# are the same each time, and we can precompute the expected result.
# r = [0, 2, 4, 6]
# c = [0, 1, 3, 4, 6]
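# E.g. with overlapping, the first pooling cell covers rows 0..2 and cols
# 0..1, so its value is (2 + 6 + 8 + 9 + 3 + 9) // 6 = 37 // 6 = 6, the
# top-left entry of `expected` below.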
# pyformat: disable
expected = np.array([
[6, 5, 3, 5],
[5, 5, 4, 5],
[5, 4, 7, 5]
]).reshape((1, 3, 4, 1))
# pyformat: enable
p, unused_r, unused_c = tf.nn.fractional_avg_pool(
mat.reshape(tensor_shape), [1, math.sqrt(3), math.sqrt(2), 1],
pseudo_random,
overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
actual = sess.run(p)
self.assertShapeEqual(expected, p)
self.assertAllClose(expected, actual)
def testDifferentTensorShapes(self):
"""Test different shapes of input tensor.
Mainly test different combinations of num_rows and num_cols.
"""
pseudo_random = True
overlapping = True
for num_batches in [1, 3]:
for num_channels in [1, 3]:
for num_rows in [10, 20, 50]:
for num_cols in [10, 20, 50]:
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
# random tensor with value in [-500.0, 500.0)
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
self._ValidateFractionalAvgPoolResult(
rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
overlapping)
def testLargePoolingRatio(self):
"""Test when pooling ratio is not within [1, 2).
"""
pseudo_random = True
overlapping = True
num_batches = 3
num_channels = 3
num_rows = 30
num_cols = 50
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
for row_ratio in [math.sqrt(11), math.sqrt(37)]:
for col_ratio in [math.sqrt(11), math.sqrt(27)]:
# random tensor with value in [-500.0, 500.0)
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
self._ValidateFractionalAvgPoolResult(rand_mat,
[1, row_ratio, col_ratio, 1],
pseudo_random, overlapping)
def testDivisiblePoolingRatio(self):
"""Test when num of rows/cols can evenly divide pooling ratio.
This is a case regular average pooling can handle. Should be handled by
fractional pooling as well.
"""
pseudo_random = True
overlapping = True
num_batches = 3
num_channels = 3
num_rows = 30
num_cols = 50
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
# random tensor with value in [-500.0, 500.0)
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
self._ValidateFractionalAvgPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
overlapping)
class FractionalAvgPoolGradTest(tf.test.TestCase):
"""Tests for FractionalAvgPoolGrad.
Two types of tests for FractionalAvgPoolGrad.
1) Test fractional_avg_pool_grad() directly.
This type of test relies on gen_nn_ops._avg_pool_grad() returning the
correct result. For example:
* input_tensor_shape = (1, 10, 10, 1)
* window_size = (1, 2, 2, 1)
* stride_size = (1, 2, 2, 1)
* padding: not really important, since 10 is evenly divisible by 2
avg pooling should generate the same result as fractional avg pooling with:
* row_sequence = [0, 2, 4, 6, 8, 10]
* col_sequence = [0, 2, 4, 6, 8, 10]
* overlapping = False
This also means their gradients in such a case will be the same.
Similarly, when
* input_tensor_shape = (1, 7, 7, 1)
* window_size = (1, 3, 3, 1)
* stride_size = (1, 2, 2, 1)
* padding: not important
avg pooling should generate the same result as fractional avg pooling with:
* row_sequence = [0, 2, 4, 7]
* col_sequence = [0, 2, 4, 7]
* overlapping = True
2) Test through compute_gradient_error()
"""
_PRNG = np.random.RandomState(341261004)
_SEED = 341261005
_SEED2 = 341261006
def _GenerateRandomInputTensor(self, shape):
num_elements = 1
for dim_size in shape:
num_elements *= dim_size
x = self._PRNG.rand(num_elements) * 1000
return x.reshape(shape)
def testDirectNotUseOverlapping(self):
for num_batches in [1, 3]:
for row_window_size in [2, 5]:
for col_window_size in [2, 4]:
num_rows = row_window_size * 5
num_cols = col_window_size * 7
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
with self.test_session() as _:
input_tensor = tf.constant(self._GenerateRandomInputTensor(
input_shape).astype(np.float32))
window_size = [1, row_window_size, col_window_size, 1]
stride_size = [1, row_window_size, col_window_size, 1]
padding = "VALID"
output_tensor = tf.nn.avg_pool(input_tensor, window_size,
stride_size, padding)
output_data = output_tensor.eval()
num_elements = 1
for dim_size in output_data.shape:
num_elements *= dim_size
output_backprop = (self._PRNG.rand(num_elements) *
1000).reshape(output_data.shape)
input_backprop_tensor = gen_nn_ops._avg_pool_grad(
input_tensor.get_shape(), output_backprop, window_size,
stride_size, padding)
input_backprop = input_backprop_tensor.eval()
row_seq = list(range(0, num_rows + 1, row_window_size))
col_seq = list(range(0, num_cols + 1, col_window_size))
fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad(
input_tensor.get_shape(),
output_backprop,
row_seq,
col_seq,
overlapping=False)
fap_input_backprop = fap_input_backprop_tensor.eval()
self.assertShapeEqual(input_backprop, fap_input_backprop_tensor)
self.assertAllClose(input_backprop, fap_input_backprop)
def testDirectUseOverlapping(self):
for num_batches in [1, 3]:
for row_window_size in [2, 5]:
for col_window_size in [2, 4]:
num_rows = (row_window_size - 1) * 5 + 1
num_cols = (col_window_size - 1) * 7 + 1
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
with self.test_session() as _:
input_tensor = tf.constant(self._GenerateRandomInputTensor(
input_shape).astype(np.float32))
window_size = [1, row_window_size, col_window_size, 1]
stride_size = [1, row_window_size - 1, col_window_size - 1, 1]
padding = "VALID"
output_tensor = tf.nn.avg_pool(input_tensor, window_size,
stride_size, padding)
output_data = output_tensor.eval()
num_elements = 1
for dim_size in output_data.shape:
num_elements *= dim_size
output_backprop = (self._PRNG.rand(num_elements) *
1000).reshape(output_data.shape)
input_backprop_tensor = gen_nn_ops._avg_pool_grad(
input_tensor.get_shape(), output_backprop, window_size,
stride_size, padding)
input_backprop = input_backprop_tensor.eval()
row_seq = list(range(0, num_rows, row_window_size - 1))
col_seq = list(range(0, num_cols, col_window_size - 1))
row_seq[-1] += 1
col_seq[-1] += 1
fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad(
input_tensor.get_shape(),
output_backprop,
row_seq,
col_seq,
overlapping=True)
fap_input_backprop = fap_input_backprop_tensor.eval()
self.assertShapeEqual(input_backprop, fap_input_backprop_tensor)
self.assertAllClose(input_backprop, fap_input_backprop)
def testAllInputOptionsThroughGradientError(self):
input_shape = (1, 7, 13, 1)
input_data = self._GenerateRandomInputTensor(input_shape)
pooling_ratio = [1, math.sqrt(2), math.sqrt(3), 1]
for pseudo_random in True, False:
for overlapping in True, False:
with self.test_session() as _:
input_tensor = tf.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool(
input_tensor,
pooling_ratio,
pseudo_random=pseudo_random,
overlapping=overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
output_data = output_tensor.eval()
output_shape = output_data.shape
# error_margin and delta setting is similar to avg_pool_grad.
error_margin = 1e-4
gradient_error = tf.test.compute_gradient_error(
input_tensor,
input_shape,
output_tensor,
output_shape,
x_init_value=input_data.reshape(input_shape),
delta=1e-2)
self.assertLess(gradient_error, error_margin)
def testDifferentTensorShapesThroughGradientError(self):
pseudo_random = True
overlapping = True
pooling_ratio = [1, math.sqrt(3), math.sqrt(2), 1]
for num_batches in [1, 2]:
for num_rows in [5, 13]:
for num_cols in [5, 11]:
for num_channels in [1, 3]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
input_data = self._GenerateRandomInputTensor(input_shape)
with self.test_session() as _:
input_tensor = tf.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool(
input_tensor,
pooling_ratio,
pseudo_random=pseudo_random,
overlapping=overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
output_data = output_tensor.eval()
output_shape = output_data.shape
# error_margin and delta setting is similar to avg_pool_grad.
error_margin = 1e-4
gradient_error = tf.test.compute_gradient_error(
input_tensor,
input_shape,
output_tensor,
output_shape,
x_init_value=input_data.reshape(input_shape),
delta=1e-2)
self.assertLess(gradient_error, error_margin)
def testLargePoolingRatioThroughGradientError(self):
input_shape = (1, 17, 23, 1)
input_data = self._GenerateRandomInputTensor(input_shape)
pooling_ratio = (1, math.sqrt(13), math.sqrt(7), 1)
output_shape = [int(a / b) for a, b in zip(input_shape, pooling_ratio)]
overlapping = True
pseudo_random = False
with self.test_session() as _:
input_tensor = tf.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool(
input_tensor,
pooling_ratio,
pseudo_random=pseudo_random,
overlapping=overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
# error_margin and delta setting is similar to avg_pool_grad.
error_margin = 1e-4
gradient_error = tf.test.compute_gradient_error(
input_tensor,
input_shape,
output_tensor,
output_shape,
x_init_value=input_data.reshape(input_shape),
delta=1e-2)
self.assertLess(gradient_error, error_margin)
if __name__ == "__main__":
tf.test.main()
View File
@ -0,0 +1,582 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fractional max pool operation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops
class FractionalMaxPoolTest(tf.test.TestCase):
# Random number generator with seed.
_PRNG = np.random.RandomState(341261)
_SEED = 123456
_SEED2 = 654321
def _MaxPoolAlongRows(self, input_matrix, row_seq, overlapping):
"""Perform max pool along row of a 2-D matrix based on row_seq.
Args:
input_matrix: A 2-D matrix.
row_seq: Cumulative pooling sequence along row.
overlapping: Whether or not to use overlapping when pooling.
Returns:
A 2-D matrix, with
* num_rows = len(row_seq)-1
* num_cols = input_matrix.num_cols.
"""
output_image = np.zeros(input_matrix.shape[1])
row_max = row_seq[-1]
for i in range(row_seq.shape[0] - 1):
row_start = row_seq[i]
row_end = row_seq[i + 1] + 1 if overlapping else row_seq[i + 1]
row_end = min(row_end, row_max)
output_image = np.vstack((output_image,
np.amax(input_matrix[row_start:row_end, :],
axis=0))) # axis 0 is along row
# remove the sentinel row
return output_image[1:, :]
def _MaxPoolAlongCols(self, input_matrix, col_seq, overlapping):
"""Perform max pool along column of a 2-D matrix based on col_seq.
Args:
input_matrix: A 2-D matrix.
col_seq: Cumulative pooling sequence along column.
overlapping: Whether or not to use overlapping when pooling.
Returns:
A 2-D matrix, with
* num_rows = input_matrix.num_rows
* num_cols = len(col_seq)-1.
"""
input_matrix = input_matrix.transpose()
output_matrix = self._MaxPoolAlongRows(input_matrix, col_seq, overlapping)
return output_matrix.transpose()
def _GetExpectedFractionalMaxPoolResult(self, input_tensor, row_seq, col_seq,
overlapping):
"""Get expected fractional max pool result.
row_seq and col_seq together define the fractional pooling region.
Args:
input_tensor: Original input tensor, assuming it is a 4-D tensor, with
dimension as [batch, height/row, width/column, channels/depth].
row_seq: Cumulative pooling sequence along row.
col_seq: Cumulative pooling sequence along column.
overlapping: Use overlapping when doing pooling.
Returns:
A 4-D tensor that is the result of max pooling on input_tensor based on
pooling region defined by row_seq and col_seq, conditioned on whether or
not overlapping is used.
"""
input_shape = input_tensor.shape
output_shape = (input_shape[0], len(row_seq) - 1, len(col_seq) - 1,
input_shape[3])
output_tensor = np.zeros(shape=output_shape, dtype=input_tensor.dtype)
for batch in range(input_shape[0]):
for channel in range(input_shape[3]):
two_dim_slice = input_tensor[batch, :, :, channel]
tmp = self._MaxPoolAlongRows(two_dim_slice, row_seq, overlapping)
output_tensor[batch, :, :, channel] = self._MaxPoolAlongCols(
tmp, col_seq, overlapping)
return output_tensor
def _ValidateFractionalMaxPoolResult(self, input_tensor, pooling_ratio,
pseudo_random, overlapping):
"""Validate FractionalMaxPool's result against expected.
Expected result is computed given input_tensor, and pooling region defined
by row_seq and col_seq.
Args:
input_tensor: A tensor or numpy ndarray.
pooling_ratio: A list or tuple of length 4; the first and last elements must be 1.
pseudo_random: Use pseudo random method to generate pooling sequence.
overlapping: Use overlapping when pooling.
Returns:
None
"""
with self.test_session() as sess:
p, r, c = tf.nn.fractional_max_pool(input_tensor,
pooling_ratio,
pseudo_random,
overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
actual, row_seq, col_seq = sess.run([p, r, c])
expected = self._GetExpectedFractionalMaxPoolResult(input_tensor, row_seq,
col_seq, overlapping)
self.assertShapeEqual(expected, p)
self.assertAllClose(expected, actual)
def _testVisually(self):
"""Manual test by printing out intermediate result of a small random tensor.
Since _GetExpectedFractionalMaxPoolResult is 'automated', it feel safer to
have a test case that you can see what's happening.
This test will generate a small, random, int 2D matrix, and feed it to
FractinalMaxPool and _GetExpectedFractionalMaxPoolResult.
"""
num_rows = 6
num_cols = 6
tensor_shape = (1, num_rows, num_cols, 1)
pseudo_random = False
for overlapping in True, False:
print("-" * 70)
print("Testing FractionalMaxPool with overlapping = {}".format(
overlapping))
rand_mat = self._PRNG.randint(10, size=tensor_shape)
pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
with self.test_session() as sess:
p, r, c = tf.nn.fractional_max_pool(rand_mat,
pooling_ratio,
pseudo_random,
overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
tensor_output, row_seq, col_seq = sess.run([p, r, c])
expected_result = self._GetExpectedFractionalMaxPoolResult(rand_mat,
row_seq,
col_seq,
overlapping)
print("row sequence:")
print(row_seq)
print("column sequence:")
print(col_seq)
print("Input:")
# Print input with pooling region marked.
for i in range(num_rows):
row_to_print = []
for j in range(num_cols):
if j in col_seq:
row_to_print.append("|")
row_to_print.append(str(rand_mat[0, i, j, 0]))
row_to_print.append("|")
if i in row_seq:
print("-" * 2 * len(row_to_print))
print(" ".join(row_to_print))
print("-" * 2 * len(row_to_print))
print("Output from FractionalMaxPool:")
print(tensor_output[0, :, :, 0])
print("Expected result:")
print(expected_result[0, :, :, 0])
def testAllInputOptions(self):
"""Try all possible input options for fractional_max_pool.
"""
num_batches = 5
num_channels = 3
num_rows = 20
num_cols = 30
for pseudo_random in True, False:
for overlapping in True, False:
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
# random tensor with value in [-500.0, 500.0)
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
self._ValidateFractionalMaxPoolResult(
rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
overlapping)
def testIntegerTensorInput(self):
"""Test it works fine when input tensor is integer type.
"""
num_batches = 5
num_channels = 3
num_rows = 20
num_cols = 30
pseudo_random = True
overlapping = True
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
rand_mat = self._PRNG.randint(1000, size=tensor_shape)
self._ValidateFractionalMaxPoolResult(rand_mat,
[1, math.sqrt(3), math.sqrt(2), 1],
pseudo_random, overlapping)
def testDifferentTensorShapes(self):
"""Test different shapes of input tensor.
Mainly test different combinations of num_rows and num_cols.
"""
pseudo_random = True
overlapping = True
for num_batches in [1, 3]:
for num_channels in [1, 3]:
for num_rows in [10, 20, 50]:
for num_cols in [10, 20, 50]:
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
# random tensor with value in [-500.0, 500.0)
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
self._ValidateFractionalMaxPoolResult(
rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
overlapping)
def testLargePoolingRatio(self):
"""Test when pooling ratio is not within [1, 2).
"""
pseudo_random = True
overlapping = True
num_batches = 3
num_channels = 3
num_rows = 30
num_cols = 50
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
for row_ratio in [math.sqrt(11), math.sqrt(37)]:
for col_ratio in [math.sqrt(11), math.sqrt(27)]:
# random tensor with value in [-500.0, 500.0)
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
self._ValidateFractionalMaxPoolResult(rand_mat,
[1, row_ratio, col_ratio, 1],
pseudo_random, overlapping)
def testDivisiblePoolingRatio(self):
"""Test when num of rows/cols can evenly divide pooling ratio.
This is a case regular max pooling can handle. Should be handled by
fractional pooling as well.
"""
pseudo_random = True
overlapping = True
num_batches = 3
num_channels = 3
num_rows = 30
num_cols = 50
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
# random tensor with value in [-500.0, 500.0)
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
self._ValidateFractionalMaxPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
overlapping)
class FractionalMaxPoolGradTest(tf.test.TestCase):
"""Tests for FractionalMaxPoolGrad.
Two types of tests for FractionalMaxPoolGrad.
1) Test fractional_max_pool_grad() directly.
This type of test relies on gen_nn_ops._max_pool_grad() returning the correct
result. For example:
* input_tensor_shape = (1, 10, 10, 1)
* window_size = (1, 2, 2, 1)
* stride_size = (1, 2, 2, 1)
* padding: not really important, since 10 is evenly divisible by 2
max pooling should generate the same result as fractional max pooling with:
* row_sequence = [0, 2, 4, 6, 8, 10]
* col_sequence = [0, 2, 4, 6, 8, 10]
* overlapping = False
This also means their gradients in such a case will be the same.
Similarly, when
* input_tensor_shape = (1, 7, 7, 1)
* window_size = (1, 3, 3, 1)
* stride_size = (1, 2, 2, 1)
* padding: not important
max pooling should generate the same result as fractional max pooling with:
* row_sequence = [0, 2, 4, 7]
* col_sequence = [0, 2, 4, 7]
* overlapping = True
2) Test through compute_gradient_error()
"""
_PRNG = np.random.RandomState(341261)
_SEED = 123456
_SEED2 = 654321
def _GenerateUniqueRandomInputTensor(self, shape):
"""Generate 'unqiue' random input tensor.
'Unique' means there's no collision values in the tensor, all elements are
different. This is done by generating sequence of integers with step of 1
and then randomly shuffle these integers.
Args:
shape: Shape of the tensor desired.
Returns:
A numpy ndarray with size = shape and dtype = numpy.float32.
"""
num_elements = 1
for size in shape:
num_elements *= size
x = np.arange(num_elements, dtype=np.float32)
self._PRNG.shuffle(x)
return x.reshape(shape)
def testDirectNotUseOverlapping(self):
for num_batches in [1, 3]:
for row_window_size in [2, 5]:
for col_window_size in [2, 4]:
num_rows = row_window_size * 5
num_cols = col_window_size * 7
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
with self.test_session() as _:
input_tensor = tf.constant(self._GenerateUniqueRandomInputTensor(
input_shape))
window_size = [1, row_window_size, col_window_size, 1]
stride_size = [1, row_window_size, col_window_size, 1]
padding = "VALID"
output_tensor = tf.nn.max_pool(input_tensor, window_size,
stride_size, padding)
output_data = output_tensor.eval()
output_backprop = self._PRNG.randint(100, size=output_data.shape)
input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor,
output_tensor,
output_backprop,
window_size,
stride_size,
padding)
input_backprop = input_backprop_tensor.eval()
row_seq = list(range(0, num_rows + 1, row_window_size))
col_seq = list(range(0, num_cols + 1, col_window_size))
fmp_input_backprop_tensor = gen_nn_ops._fractional_max_pool_grad(
input_tensor,
output_tensor,
output_backprop,
row_seq,
col_seq,
overlapping=False)
fmp_input_backprop = fmp_input_backprop_tensor.eval()
self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
self.assertAllClose(input_backprop, fmp_input_backprop)
def testDirectUseOverlapping(self):
for num_batches in [1, 3]:
for row_window_size in [2, 5]:
for col_window_size in [2, 4]:
num_rows = (row_window_size - 1) * 5 + 1
num_cols = (col_window_size - 1) * 7 + 1
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
with self.test_session() as _:
input_tensor = tf.constant(self._GenerateUniqueRandomInputTensor(
input_shape))
window_size = [1, row_window_size, col_window_size, 1]
stride_size = [1, row_window_size - 1, col_window_size - 1, 1]
padding = "VALID"
output_tensor = tf.nn.max_pool(input_tensor, window_size,
stride_size, padding)
output_data = output_tensor.eval()
output_backprop = self._PRNG.randint(100, size=output_data.shape)
input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor,
output_tensor,
output_backprop,
window_size,
stride_size,
padding)
input_backprop = input_backprop_tensor.eval()
row_seq = list(range(0, num_rows, row_window_size - 1))
col_seq = list(range(0, num_cols, col_window_size - 1))
row_seq[-1] += 1
col_seq[-1] += 1
fmp_input_backprop_tensor = gen_nn_ops._fractional_max_pool_grad(
input_tensor,
output_tensor,
output_backprop,
row_seq,
col_seq,
overlapping=True)
fmp_input_backprop = fmp_input_backprop_tensor.eval()
self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
self.assertAllClose(input_backprop, fmp_input_backprop)
def testAllInputOptionsThroughGradientError(self):
input_shape = (1, 7, 13, 1)
input_data = self._GenerateUniqueRandomInputTensor(input_shape)
# Add some randomness to make input_data not so 'integer'
input_data += self._PRNG.random_sample(input_shape)
pooling_ratio = [1, math.sqrt(2), math.sqrt(3), 1]
for pseudo_random in True, False:
for overlapping in True, False:
with self.test_session() as _:
input_tensor = tf.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool(
input_tensor,
pooling_ratio,
pseudo_random=pseudo_random,
overlapping=overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
output_data = output_tensor.eval()
output_shape = output_data.shape
# error_margin and delta setting is similar to max_pool_grad.
error_margin = 1e-3
gradient_error = tf.test.compute_gradient_error(
input_tensor,
input_shape,
output_tensor,
output_shape,
x_init_value=input_data.reshape(input_shape),
delta=1e-2)
self.assertLess(gradient_error, error_margin)
def testDifferentTensorShapesThroughGradientError(self):
pseudo_random = True
overlapping = True
pooling_ratio = [1, math.sqrt(3), math.sqrt(2), 1]
for num_batches in [1, 2]:
for num_rows in [5, 13]:
for num_cols in [5, 11]:
for num_channels in [1, 3]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
input_data = self._GenerateUniqueRandomInputTensor(input_shape)
# Add some randomness to make input_data not so 'integer'
input_data += self._PRNG.random_sample(input_shape)
with self.test_session() as _:
input_tensor = tf.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool(
input_tensor,
pooling_ratio,
pseudo_random=pseudo_random,
overlapping=overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
output_data = output_tensor.eval()
output_shape = output_data.shape
# error_margin and delta setting is similar to max_pool_grad.
error_margin = 1e-3
gradient_error = tf.test.compute_gradient_error(
input_tensor,
input_shape,
output_tensor,
output_shape,
x_init_value=input_data.reshape(input_shape),
delta=1e-2)
self.assertLess(gradient_error, error_margin)
def testLargePoolingRatioThroughGradientError(self):
input_shape = (1, 17, 23, 1)
input_data = self._GenerateUniqueRandomInputTensor(input_shape)
# Add some randomness to make input_data not so 'integer'
input_data += self._PRNG.random_sample(input_shape)
pooling_ratio = (1, math.sqrt(13), math.sqrt(7), 1)
output_shape = [int(a / b) for a, b in zip(input_shape, pooling_ratio)]
overlapping = True
pseudo_random = False
with self.test_session() as _:
input_tensor = tf.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool(
input_tensor,
pooling_ratio,
pseudo_random=pseudo_random,
overlapping=overlapping,
deterministic=True,
seed=self._SEED,
seed2=self._SEED2)
# error_margin and delta setting is similar to max_pool_grad.
error_margin = 1e-3
gradient_error = tf.test.compute_gradient_error(
input_tensor,
input_shape,
output_tensor,
output_shape,
x_init_value=input_data.reshape(input_shape),
delta=1e-2)
self.assertLess(gradient_error, error_margin)
def testWhenRepeatedMaxValueInPoolingRegion(self):
"""Test when there's repeating value in pooling region.
There's no formal definition for what the gradient should be when there're
multiple max value within a pooling cell. Such as
| 1 5 |
| 5 3 |
The expected result depends heavily on implementation, if someone swap the
order of a nested for loop when walking through the tensor, result would be
very different.
The goal of this test is to alert when someone else change the
implementation. Current implementation scans row-by-row.
"""
input_data = [5.0, 4.0, 6.0, 7.0,
3.0, 5.0, 9.0, 6.0,
8.0, 8.0, 9.0, 5.0,
7.0, 4.0, 0.0, 0.0] # pyformat: disable
input_size = [1, 4, 4, 1]
output_backprop = [12.0, 15.0,
17.0, -5.0,
6.0, 21.0] # pyformat: disable
row_seq = [0, 1, 3, 4]
col_seq = [0, 2, 4]
output_data_not_overlapping = [5.0, 7.0,
8.0, 9.0,
7.0, 0.0] # pyformat: disable
output_data_overlapping = [9.0, 9.0,
9.0, 9.0,
7.0, 0.0] # pyformat: disable
output_size = [1, 3, 2, 1]
expected_input_backprop_not_overlapping = np.reshape(
[12.0, 0.0, 0.0, 15.0,
0.0, 0.0, -5.0, 0.0,
17.0, 0.0, 0.0, 0.0,
6.0, 0.0, 21.0, 0.0],
input_size) # pyformat: disable
expected_input_backprop_overlapping = np.reshape(
[0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 39.0, 0.0,
0.0, 0.0, 0.0, 0.0,
6.0, 0.0, 21.0, 0.0],
input_size) # pyformat: disable
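# With overlapping, the 9.0 at input position (1, 2) is the first max (in
# row-major scan order) of each of the first four pooling cells, so it
# accumulates all four gradients: 12 + 15 + 17 - 5 = 39.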
with self.test_session() as _:
# Test when overlapping is False
input_tensor = tf.constant(input_data, shape=input_size)
output_tensor = tf.constant(output_data_not_overlapping,
shape=output_size)
grad = tf.constant(output_backprop, shape=output_size)
r = gen_nn_ops._fractional_max_pool_grad(
input_tensor,
output_tensor,
grad,
row_seq,
col_seq,
overlapping=False)
input_backprop_not_overlapping = r.eval()
self.assertShapeEqual(
np.reshape(expected_input_backprop_not_overlapping, input_size), r)
self.assertAllClose(expected_input_backprop_not_overlapping,
input_backprop_not_overlapping)
# Test when overlapping is True
output_tensor = tf.constant(output_data_overlapping, shape=output_size)
r = gen_nn_ops._fractional_max_pool_grad(
input_tensor, output_tensor, grad, row_seq, col_seq, overlapping=True)
input_backprop_overlapping = r.eval()
self.assertShapeEqual(
np.reshape(expected_input_backprop_overlapping, input_size), r)
self.assertAllClose(expected_input_backprop_overlapping,
input_backprop_overlapping)
if __name__ == "__main__":
tf.test.main()
View File
@ -181,6 +181,8 @@ AvgPool
MaxPool
Softmax
LogSoftmax
FractionalAvgPoolGrad
FractionalMaxPoolGrad
# parsing_ops
ParseExample
View File
@ -132,6 +132,8 @@ to the `Convolution` section for details about the padding calculation.
@@max_pool_with_argmax
@@avg_pool3d
@@max_pool3d
@@fractional_avg_pool
@@fractional_max_pool
## Morphological filtering
View File
@ -361,6 +361,53 @@ def _MaxPoolGrad(op, grad):
data_format=op.get_attr("data_format"))
@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
"""Returns gradient for FractionalMaxPool.
Since FractionalMaxPool has three outputs, three gradients are passed in,
one for each output. Only the first one is useful; the other two gradients
are empty.
Args:
op: The FractionalMaxPoolOp.
grad_0: Gradient with respect to op.outputs[0]
unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.
Returns:
Input backprop for FractionalMaxPool op.
"""
# pylint: disable=protected-access
return gen_nn_ops._fractional_max_pool_grad(op.inputs[0], op.outputs[0],
grad_0, op.outputs[1],
op.outputs[2],
op.get_attr("overlapping"))
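# With the registration above, tf.gradients() flows through the op
# automatically; a hedged sketch (tensor names arbitrary):
#   out, _, _ = tf.nn.fractional_max_pool(x, [1.0, 1.5, 1.5, 1.0])
#   dx, = tf.gradients(out, [x])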
@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
"""Returns gradient for FractionalAvgPool.
Since FractionalAvgPool has three outputs, three gradients are passed in,
one for each output. Only the first one is useful; the other two gradients
are empty.
Args:
op: The FractionalAvgPoolOp.
grad_0: Gradient with respect to op.outputs[0]
unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.
Returns:
Input backprop for FractionalAvgPool op.
"""
# pylint: disable=protected-access
return gen_nn_ops._fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
op.outputs[1], op.outputs[2],
op.get_attr("overlapping"))
@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
"""Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.
View File
@ -941,6 +941,39 @@ def _AvgPoolGradShape(op):
return [tensor_shape.unknown_shape(ndims=4)]
@ops.RegisterShape("FractionalMaxPool")
@ops.RegisterShape("FractionalAvgPool")
def _fractional_pool_shape(op):
input_dims = op.inputs[0].get_shape().with_rank(4).as_list()
pooling_ratio = op.get_attr("pooling_ratio")
output_dims = np.divide(input_dims, pooling_ratio).astype(int)
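# E.g. input dims [3, 30, 50, 3] with pooling_ratio (1, sqrt(13), sqrt(7), 1)
# truncate to output_dims [3, 8, 18, 3].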
return [
# output.
tensor_shape.TensorShape(output_dims),
# row_pooling_sequence.
tensor_shape.TensorShape([output_dims[1]]),
# col_pooling_sequence.
tensor_shape.TensorShape([output_dims[2]])
]
@ops.RegisterShape("FractionalMaxPoolGrad")
def _fractional_max_pool_grad_shape(op):
"""Shape function for the FractionalMaxPoolGrad op."""
orig_input_shape = op.inputs[0].get_shape().with_rank(4)
return [orig_input_shape]
@ops.RegisterShape("FractionalAvgPoolGrad")
def _fractional_avg_pool_grad_shape(op):
"""Shape function for the FractionalAvgPoolGrad op."""
orig_input_shape = tensor_util.constant_value(op.inputs[0])
if orig_input_shape is not None:
return [tensor_shape.TensorShape(orig_input_shape.tolist())]
else:
return [tensor_shape.unknown_shape(ndims=4)]
@ops.RegisterShape("Conv2DBackpropFilter")
def _Conv2DBackpropFilterShape(op):
"""Shape function for the Conv2DBackpropFilter op."""