Adds fractional_max_pool and fractional_avg_pool ops. Fixes #2953.
Change: 131754627
This commit is contained in:
parent
79d8721bf2
commit
8b667b7d4b
@ -1448,6 +1448,9 @@ tf_kernel_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"avgpooling_op.cc",
|
"avgpooling_op.cc",
|
||||||
"cudnn_pooling_gpu.cc",
|
"cudnn_pooling_gpu.cc",
|
||||||
|
"fractional_avg_pool_op.cc",
|
||||||
|
"fractional_max_pool_op.cc",
|
||||||
|
"fractional_pool_common.cc",
|
||||||
"maxpooling_op.cc",
|
"maxpooling_op.cc",
|
||||||
"pooling_ops_3d.cc",
|
"pooling_ops_3d.cc",
|
||||||
"pooling_ops_common.cc",
|
"pooling_ops_common.cc",
|
||||||
@ -1455,6 +1458,7 @@ tf_kernel_library(
|
|||||||
hdrs = [
|
hdrs = [
|
||||||
"avgpooling_op.h",
|
"avgpooling_op.h",
|
||||||
"cudnn_pooling_gpu.h",
|
"cudnn_pooling_gpu.h",
|
||||||
|
"fractional_pool_common.h",
|
||||||
"maxpooling_op.h",
|
"maxpooling_op.h",
|
||||||
"pooling_ops_common.h",
|
"pooling_ops_common.h",
|
||||||
],
|
],
|
||||||
|
354
tensorflow/core/kernels/fractional_avg_pool_op.cc
Normal file
354
tensorflow/core/kernels/fractional_avg_pool_op.cc
Normal file
@ -0,0 +1,354 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <random>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/fractional_pool_common.h"
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/numeric_op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/util/guarded_philox_random.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class FractionalAvgPoolOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit FractionalAvgPoolOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
|
||||||
|
OP_REQUIRES(context, pooling_ratio_.size() == 4,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"pooling_ratio field must specify 4 dimensions"));
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
|
||||||
|
errors::Unimplemented("Fractional average pooling is not yet "
|
||||||
|
"supported on the batch nor channel dimension."));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
|
||||||
|
pooling_region_generated_ = false;
|
||||||
|
// Initialize philox random generator.
|
||||||
|
OP_REQUIRES_OK(context, generator_.Init(context));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
ConstEigenMatrixMap;
|
||||||
|
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
EigenMatrixMap;
|
||||||
|
|
||||||
|
constexpr int tensor_in_and_out_dims = 4;
|
||||||
|
|
||||||
|
const Tensor& tensor_in = context->input(0);
|
||||||
|
OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
|
||||||
|
errors::InvalidArgument("tensor_in must be 4-dimensional"));
|
||||||
|
|
||||||
|
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
|
||||||
|
input_size_.push_back(tensor_in.dim_size(i));
|
||||||
|
}
|
||||||
|
// Output size.
|
||||||
|
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
|
||||||
|
output_size_.push_back(
|
||||||
|
static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
|
||||||
|
DCHECK_GT(output_size_[i], 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate pooling sequence.
|
||||||
|
std::vector<int64> row_cum_seq;
|
||||||
|
std::vector<int64> col_cum_seq;
|
||||||
|
if (deterministic_) {
|
||||||
|
if (pooling_region_generated_) {
|
||||||
|
row_cum_seq = row_cum_seq_;
|
||||||
|
col_cum_seq = col_cum_seq_;
|
||||||
|
} else {
|
||||||
|
row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
|
||||||
|
&generator_, pseudo_random_);
|
||||||
|
col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
|
||||||
|
&generator_, pseudo_random_);
|
||||||
|
mutex_lock lock(mu_);
|
||||||
|
row_cum_seq_ = row_cum_seq;
|
||||||
|
col_cum_seq_ = col_cum_seq;
|
||||||
|
pooling_region_generated_ = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
|
||||||
|
&generator_, pseudo_random_);
|
||||||
|
col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
|
||||||
|
&generator_, pseudo_random_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare output.
|
||||||
|
Tensor* output_tensor = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(
|
||||||
|
0, TensorShape({output_size_[0], output_size_[1],
|
||||||
|
output_size_[2], output_size_[3]}),
|
||||||
|
&output_tensor));
|
||||||
|
Tensor* output_row_seq_tensor = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(
|
||||||
|
1, TensorShape({static_cast<int64>(row_cum_seq.size())}),
|
||||||
|
&output_row_seq_tensor));
|
||||||
|
Tensor* output_col_seq_tensor = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(
|
||||||
|
2, TensorShape({static_cast<int64>(col_cum_seq.size())}),
|
||||||
|
&output_col_seq_tensor));
|
||||||
|
|
||||||
|
ConstEigenMatrixMap in_mat(
|
||||||
|
tensor_in.flat<T>().data(), input_size_[3],
|
||||||
|
input_size_[2] * input_size_[1] * input_size_[0]);
|
||||||
|
|
||||||
|
EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
|
||||||
|
output_size_[2] * output_size_[1] * output_size_[0]);
|
||||||
|
// out_count corresponds to number of elements in each pooling cell.
|
||||||
|
Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols());
|
||||||
|
|
||||||
|
// Initializes the output tensor and out_count with 0.
|
||||||
|
out_mat.setZero();
|
||||||
|
out_count.setZero();
|
||||||
|
|
||||||
|
auto output_row_seq_flat = output_row_seq_tensor->flat<int64>();
|
||||||
|
auto output_col_seq_flat = output_col_seq_tensor->flat<int64>();
|
||||||
|
|
||||||
|
// Set output tensors.
|
||||||
|
for (int i = 0; i < row_cum_seq.size(); ++i) {
|
||||||
|
output_row_seq_flat(i) = row_cum_seq[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < col_cum_seq.size(); ++i) {
|
||||||
|
output_col_seq_flat(i) = col_cum_seq[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// For both input and output,
|
||||||
|
// 0: batch
|
||||||
|
// 1: row / row
|
||||||
|
// 2: col / col
|
||||||
|
// 3: depth / channel
|
||||||
|
const int64 row_max = input_size_[1] - 1;
|
||||||
|
const int64 col_max = input_size_[2] - 1;
|
||||||
|
for (int64 b = 0; b < input_size_[0]; ++b) {
|
||||||
|
// row sequence.
|
||||||
|
for (int64 hs = 0; hs < row_cum_seq.size() - 1; ++hs) {
|
||||||
|
// row start and end.
|
||||||
|
const int64 row_start = row_cum_seq[hs];
|
||||||
|
int64 row_end =
|
||||||
|
overlapping_ ? row_cum_seq[hs + 1] : row_cum_seq[hs + 1] - 1;
|
||||||
|
row_end = std::min(row_end, row_max);
|
||||||
|
|
||||||
|
// col sequence.
|
||||||
|
for (int64 ws = 0; ws < col_cum_seq.size() - 1; ++ws) {
|
||||||
|
const int64 out_offset =
|
||||||
|
(b * output_size_[1] + hs) * output_size_[2] + ws;
|
||||||
|
// col start and end.
|
||||||
|
const int64 col_start = col_cum_seq[ws];
|
||||||
|
int64 col_end =
|
||||||
|
overlapping_ ? col_cum_seq[ws + 1] : col_cum_seq[ws + 1] - 1;
|
||||||
|
col_end = std::min(col_end, col_max);
|
||||||
|
for (int64 h = row_start; h <= row_end; ++h) {
|
||||||
|
for (int64 w = col_start; w <= col_end; ++w) {
|
||||||
|
const int64 in_offset =
|
||||||
|
(b * input_size_[1] + h) * input_size_[2] + w;
|
||||||
|
out_mat.col(out_offset) += in_mat.col(in_offset);
|
||||||
|
out_count(out_offset)++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DCHECK_GT(out_count.minCoeff(), 0);
|
||||||
|
out_mat.array().rowwise() /= out_count.transpose().array();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool deterministic_;
|
||||||
|
// meaningful only when deterministic_ is true.
|
||||||
|
mutex mu_;
|
||||||
|
std::vector<int64> row_cum_seq_;
|
||||||
|
std::vector<int64> col_cum_seq_;
|
||||||
|
bool pooling_region_generated_;
|
||||||
|
|
||||||
|
std::vector<int32> input_size_;
|
||||||
|
std::vector<int32> output_size_;
|
||||||
|
std::vector<float> pooling_ratio_;
|
||||||
|
bool pseudo_random_;
|
||||||
|
bool overlapping_;
|
||||||
|
GuardedPhiloxRandom generator_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_FRACTIONALAVGPOOL(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("FractionalAvgPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||||
|
FractionalAvgPoolOp<type>)
|
||||||
|
|
||||||
|
REGISTER_FRACTIONALAVGPOOL(int32);
|
||||||
|
REGISTER_FRACTIONALAVGPOOL(int64);
|
||||||
|
REGISTER_FRACTIONALAVGPOOL(float);
|
||||||
|
REGISTER_FRACTIONALAVGPOOL(double);
|
||||||
|
|
||||||
|
#undef REGISTER_FRACTIONALAVGPOOL
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
class FractionalAvgPoolGradOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit FractionalAvgPoolGradOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
// Here's the basic idea:
|
||||||
|
// Batch and depth dimension are independent from row and col dimension. And
|
||||||
|
// because FractionalAvgPool currently only support pooling along row and
|
||||||
|
// col, we can basically think of this 4D tensor backpropagation as
|
||||||
|
// operation of a series of 2D planes.
|
||||||
|
//
|
||||||
|
// For each element of a 'slice' (2D plane) of output_backprop, we need to
|
||||||
|
// figure out its contributors when doing FractionalAvgPool operation. This
|
||||||
|
// can be done based on row_pooling_sequence, col_pooling_seq and
|
||||||
|
// overlapping.
|
||||||
|
// Once we figure out the original contributors, we just need to evenly
|
||||||
|
// divide the value of this element among these contributors.
|
||||||
|
//
|
||||||
|
// Internally, we divide the out_backprop tensor and store it in a temparary
|
||||||
|
// tensor of double type. And cast it to the corresponding type.
|
||||||
|
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
ConstEigenMatrixMap;
|
||||||
|
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
EigenMatrixMap;
|
||||||
|
typedef Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
EigenDoubleMatrixMap;
|
||||||
|
|
||||||
|
// Grab the inputs.
|
||||||
|
const Tensor& orig_input_tensor_shape = context->input(0);
|
||||||
|
OP_REQUIRES(context, orig_input_tensor_shape.dims() == 1 &&
|
||||||
|
orig_input_tensor_shape.NumElements() == 4,
|
||||||
|
errors::InvalidArgument("original input tensor shape must be"
|
||||||
|
"1-dimensional and 4 elements"));
|
||||||
|
const Tensor& out_backprop = context->input(1);
|
||||||
|
const Tensor& row_seq_tensor = context->input(2);
|
||||||
|
const Tensor& col_seq_tensor = context->input(3);
|
||||||
|
|
||||||
|
const int64 out_batch = out_backprop.dim_size(0);
|
||||||
|
const int64 out_rows = out_backprop.dim_size(1);
|
||||||
|
const int64 out_cols = out_backprop.dim_size(2);
|
||||||
|
const int64 out_depth = out_backprop.dim_size(3);
|
||||||
|
|
||||||
|
auto row_seq_tensor_flat = row_seq_tensor.flat<int64>();
|
||||||
|
auto col_seq_tensor_flat = col_seq_tensor.flat<int64>();
|
||||||
|
auto orig_input_tensor_shape_flat = orig_input_tensor_shape.flat<int64>();
|
||||||
|
|
||||||
|
const int64 in_batch = orig_input_tensor_shape_flat(0);
|
||||||
|
const int64 in_rows = orig_input_tensor_shape_flat(1);
|
||||||
|
const int64 in_cols = orig_input_tensor_shape_flat(2);
|
||||||
|
const int64 in_depth = orig_input_tensor_shape_flat(3);
|
||||||
|
|
||||||
|
constexpr int tensor_in_and_out_dims = 4;
|
||||||
|
// Transform orig_input_tensor_shape into TensorShape
|
||||||
|
TensorShape in_shape;
|
||||||
|
for (auto i = 0; i < tensor_in_and_out_dims; ++i) {
|
||||||
|
in_shape.AddDim(orig_input_tensor_shape_flat(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create intermediate in_backprop.
|
||||||
|
Tensor in_backprop_tensor_temp;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_temp(DataTypeToEnum<double>::v(), in_shape,
|
||||||
|
&in_backprop_tensor_temp));
|
||||||
|
in_backprop_tensor_temp.flat<double>().setZero();
|
||||||
|
// Transform 4D tensor to 2D matrix.
|
||||||
|
EigenDoubleMatrixMap in_backprop_tensor_temp_mat(
|
||||||
|
in_backprop_tensor_temp.flat<double>().data(), in_depth,
|
||||||
|
in_cols * in_rows * in_batch);
|
||||||
|
ConstEigenMatrixMap out_backprop_mat(out_backprop.flat<T>().data(),
|
||||||
|
out_depth,
|
||||||
|
out_cols * out_rows * out_batch);
|
||||||
|
// Loop through each element of out_backprop and evenly distribute the
|
||||||
|
// element to the corresponding pooling cell.
|
||||||
|
const int64 in_max_row_index = in_rows - 1;
|
||||||
|
const int64 in_max_col_index = in_cols - 1;
|
||||||
|
for (int64 b = 0; b < out_batch; ++b) {
|
||||||
|
for (int64 r = 0; r < out_rows; ++r) {
|
||||||
|
const int64 in_row_start = row_seq_tensor_flat(r);
|
||||||
|
int64 in_row_end = overlapping_ ? row_seq_tensor_flat(r + 1)
|
||||||
|
: row_seq_tensor_flat(r + 1) - 1;
|
||||||
|
in_row_end = std::min(in_row_end, in_max_row_index);
|
||||||
|
for (int64 c = 0; c < out_cols; ++c) {
|
||||||
|
const int64 in_col_start = col_seq_tensor_flat(c);
|
||||||
|
int64 in_col_end = overlapping_ ? col_seq_tensor_flat(c + 1)
|
||||||
|
: col_seq_tensor_flat(c + 1) - 1;
|
||||||
|
in_col_end = std::min(in_col_end, in_max_col_index);
|
||||||
|
|
||||||
|
const int64 num_elements_in_pooling_cell =
|
||||||
|
(in_row_end - in_row_start + 1) * (in_col_end - in_col_start + 1);
|
||||||
|
const int64 out_index = (b * out_rows + r) * out_cols + c;
|
||||||
|
// Now we can evenly distribute out_backprop(b, h, w, *) to
|
||||||
|
// in_backprop(b, hs:he, ws:we, *).
|
||||||
|
for (int64 in_r = in_row_start; in_r <= in_row_end; ++in_r) {
|
||||||
|
for (int64 in_c = in_col_start; in_c <= in_col_end; ++in_c) {
|
||||||
|
const int64 in_index = (b * in_rows + in_r) * in_cols + in_c;
|
||||||
|
// Walk through each channel (depth).
|
||||||
|
for (int64 d = 0; d < out_depth; ++d) {
|
||||||
|
const double out_backprop_element = static_cast<double>(
|
||||||
|
out_backprop_mat.coeffRef(d, out_index));
|
||||||
|
double& in_backprop_ref =
|
||||||
|
in_backprop_tensor_temp_mat.coeffRef(d, in_index);
|
||||||
|
in_backprop_ref +=
|
||||||
|
out_backprop_element / num_elements_in_pooling_cell;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Depending on the type, cast double to type T.
|
||||||
|
Tensor* in_backprop_tensor = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(0, in_shape, &in_backprop_tensor));
|
||||||
|
auto in_backprop_tensor_flat = in_backprop_tensor->flat<T>();
|
||||||
|
auto in_backprop_tensor_temp_flat = in_backprop_tensor_temp.flat<double>();
|
||||||
|
for (int64 i = 0; i < in_backprop_tensor_flat.size(); ++i) {
|
||||||
|
in_backprop_tensor_flat(i) =
|
||||||
|
static_cast<T>(in_backprop_tensor_temp_flat(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool overlapping_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_FRACTIONALAVGPOOLGRAD(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("FractionalAvgPoolGrad") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<type>("T"), \
|
||||||
|
FractionalAvgPoolGradOp<type>)
|
||||||
|
|
||||||
|
REGISTER_FRACTIONALAVGPOOLGRAD(int32);
|
||||||
|
REGISTER_FRACTIONALAVGPOOLGRAD(int64);
|
||||||
|
REGISTER_FRACTIONALAVGPOOLGRAD(float);
|
||||||
|
REGISTER_FRACTIONALAVGPOOLGRAD(double);
|
||||||
|
|
||||||
|
#undef REGISTER_FRACTIONALAVGPOOLGRAD
|
||||||
|
} // namespace tensorflow
|
381
tensorflow/core/kernels/fractional_max_pool_op.cc
Normal file
381
tensorflow/core/kernels/fractional_max_pool_op.cc
Normal file
@ -0,0 +1,381 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <random>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/fractional_pool_common.h"
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/numeric_op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/util/guarded_philox_random.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class FractionalMaxPoolOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit FractionalMaxPoolOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
|
||||||
|
|
||||||
|
OP_REQUIRES(context, pooling_ratio_.size() == 4,
|
||||||
|
errors::InvalidArgument("pooling_ratio field must "
|
||||||
|
"specify 4 dimensions"));
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
|
||||||
|
errors::Unimplemented("Fractional max pooling is not yet "
|
||||||
|
"supported on the batch nor channel dimension."));
|
||||||
|
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
|
||||||
|
pooling_region_generated_ = false;
|
||||||
|
// Initialize philox random generator.
|
||||||
|
OP_REQUIRES_OK(context, generator_.Init(context));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
ConstEigenMatrixMap;
|
||||||
|
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
EigenMatrixMap;
|
||||||
|
|
||||||
|
constexpr int tensor_in_and_out_dims = 4;
|
||||||
|
|
||||||
|
const Tensor& tensor_in = context->input(0);
|
||||||
|
OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
|
||||||
|
errors::InvalidArgument("tensor_in must be 4-dimensional"));
|
||||||
|
|
||||||
|
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
|
||||||
|
input_size_.push_back(tensor_in.dim_size(i));
|
||||||
|
}
|
||||||
|
// Output size.
|
||||||
|
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
|
||||||
|
output_size_.push_back(
|
||||||
|
static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
|
||||||
|
DCHECK_GT(output_size_[i], 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate pooling sequence.
|
||||||
|
std::vector<int64> height_cum_seq;
|
||||||
|
std::vector<int64> width_cum_seq;
|
||||||
|
if (deterministic_) {
|
||||||
|
if (pooling_region_generated_) {
|
||||||
|
height_cum_seq = height_cum_seq_;
|
||||||
|
width_cum_seq = width_cum_seq_;
|
||||||
|
} else {
|
||||||
|
height_cum_seq = GeneratePoolingSequence(
|
||||||
|
input_size_[1], output_size_[1], &generator_, pseudo_random_);
|
||||||
|
width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
|
||||||
|
&generator_, pseudo_random_);
|
||||||
|
mutex_lock lock(mu_);
|
||||||
|
height_cum_seq_ = height_cum_seq;
|
||||||
|
width_cum_seq_ = width_cum_seq;
|
||||||
|
pooling_region_generated_ = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
height_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
|
||||||
|
&generator_, pseudo_random_);
|
||||||
|
width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
|
||||||
|
&generator_, pseudo_random_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare output.
|
||||||
|
Tensor* output_tensor = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(
|
||||||
|
0, TensorShape({output_size_[0], output_size_[1],
|
||||||
|
output_size_[2], output_size_[3]}),
|
||||||
|
&output_tensor));
|
||||||
|
Tensor* output_height_seq_tensor = nullptr;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context,
|
||||||
|
context->allocate_output(
|
||||||
|
1, TensorShape({static_cast<int64>(height_cum_seq.size())}),
|
||||||
|
&output_height_seq_tensor));
|
||||||
|
Tensor* output_width_seq_tensor = nullptr;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context, context->allocate_output(
|
||||||
|
2, TensorShape({static_cast<int64>(width_cum_seq.size())}),
|
||||||
|
&output_width_seq_tensor));
|
||||||
|
|
||||||
|
ConstEigenMatrixMap in_mat(
|
||||||
|
tensor_in.flat<T>().data(), input_size_[3],
|
||||||
|
input_size_[2] * input_size_[1] * input_size_[0]);
|
||||||
|
|
||||||
|
EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
|
||||||
|
output_size_[2] * output_size_[1] * output_size_[0]);
|
||||||
|
|
||||||
|
// Initializes the output tensor with MIN<T>.
|
||||||
|
output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
|
||||||
|
|
||||||
|
auto output_height_seq_flat = output_height_seq_tensor->flat<int64>();
|
||||||
|
auto output_width_seq_flat = output_width_seq_tensor->flat<int64>();
|
||||||
|
|
||||||
|
// Set output tensors.
|
||||||
|
for (int i = 0; i < height_cum_seq.size(); ++i) {
|
||||||
|
output_height_seq_flat(i) = height_cum_seq[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < width_cum_seq.size(); ++i) {
|
||||||
|
output_width_seq_flat(i) = width_cum_seq[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// For both input and output,
|
||||||
|
// 0: batch
|
||||||
|
// 1: height / row
|
||||||
|
// 2: width / col
|
||||||
|
// 3: depth / channel
|
||||||
|
const int64 height_max = input_size_[1] - 1;
|
||||||
|
const int64 width_max = input_size_[2] - 1;
|
||||||
|
for (int64 b = 0; b < input_size_[0]; ++b) {
|
||||||
|
// height sequence.
|
||||||
|
for (int64 hs = 0; hs < height_cum_seq.size() - 1; ++hs) {
|
||||||
|
// height start and end.
|
||||||
|
const int64 height_start = height_cum_seq[hs];
|
||||||
|
int64 height_end =
|
||||||
|
overlapping_ ? height_cum_seq[hs + 1] : height_cum_seq[hs + 1] - 1;
|
||||||
|
height_end = std::min(height_end, height_max);
|
||||||
|
|
||||||
|
// width sequence.
|
||||||
|
for (int64 ws = 0; ws < width_cum_seq.size() - 1; ++ws) {
|
||||||
|
const int64 out_offset =
|
||||||
|
(b * output_size_[1] + hs) * output_size_[2] + ws;
|
||||||
|
// width start and end.
|
||||||
|
const int64 width_start = width_cum_seq[ws];
|
||||||
|
int64 width_end =
|
||||||
|
overlapping_ ? width_cum_seq[ws + 1] : width_cum_seq[ws + 1] - 1;
|
||||||
|
width_end = std::min(width_end, width_max);
|
||||||
|
for (int64 h = height_start; h <= height_end; ++h) {
|
||||||
|
for (int64 w = width_start; w <= width_end; ++w) {
|
||||||
|
const int64 in_offset =
|
||||||
|
(b * input_size_[1] + h) * input_size_[2] + w;
|
||||||
|
out_mat.col(out_offset) =
|
||||||
|
out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool deterministic_;
|
||||||
|
// meaningful only when deterministic_ is true.
|
||||||
|
mutex mu_;
|
||||||
|
std::vector<int64> height_cum_seq_;
|
||||||
|
std::vector<int64> width_cum_seq_;
|
||||||
|
bool pooling_region_generated_;
|
||||||
|
|
||||||
|
std::vector<int32> input_size_;
|
||||||
|
std::vector<int32> output_size_;
|
||||||
|
std::vector<float> pooling_ratio_;
|
||||||
|
bool pseudo_random_;
|
||||||
|
bool overlapping_;
|
||||||
|
GuardedPhiloxRandom generator_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_FRACTIONALMAXPOOL(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("FractionalMaxPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||||
|
FractionalMaxPoolOp<type>)
|
||||||
|
|
||||||
|
REGISTER_FRACTIONALMAXPOOL(int32);
|
||||||
|
REGISTER_FRACTIONALMAXPOOL(int64);
|
||||||
|
REGISTER_FRACTIONALMAXPOOL(float);
|
||||||
|
REGISTER_FRACTIONALMAXPOOL(double);
|
||||||
|
|
||||||
|
#undef REGISTER_FRACTIONALMAXPOOL
|
||||||
|
|
||||||
|
static const int kInvalidMaxPoolingIndex = -1;
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
class FractionalMaxPoolGradOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit FractionalMaxPoolGradOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
// There are two steps when calculating gradient for FractionalMaxPool.
|
||||||
|
// 1) Walk through the process of calculating fractional pooling given
|
||||||
|
// pooling region; however, in the process, keep track of where the max
|
||||||
|
// element comes from. (arg_max)
|
||||||
|
// 2) Populate the value of out_backprop to where arg_max indicates. If
|
||||||
|
// we support overlapping, it is likely to have multiple out_backprop[i]
|
||||||
|
// propagates back to the same arg_max value.
|
||||||
|
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
ConstEigenMatrixMap;
|
||||||
|
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
EigenMatrixMap;
|
||||||
|
typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
EigenIndexMatrixMap;
|
||||||
|
|
||||||
|
const Tensor& tensor_in = context->input(0);
|
||||||
|
const Tensor& tensor_out = context->input(1);
|
||||||
|
const Tensor& out_backprop = context->input(2);
|
||||||
|
const Tensor& height_seq_tensor = context->input(3);
|
||||||
|
const Tensor& width_seq_tensor = context->input(4);
|
||||||
|
|
||||||
|
// Just to make it similar to FractionalMaxPoolOp.
|
||||||
|
constexpr int tensor_in_and_out_dims = 4;
|
||||||
|
std::vector<int64> input_size;
|
||||||
|
std::vector<int64> output_size;
|
||||||
|
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
|
||||||
|
input_size.push_back(tensor_in.dim_size(i));
|
||||||
|
}
|
||||||
|
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
|
||||||
|
output_size.push_back(tensor_out.dim_size(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------
|
||||||
|
// Step 1
|
||||||
|
// ---------
|
||||||
|
Tensor tensor_out_dup;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_temp(DataTypeToEnum<T>::v(),
|
||||||
|
tensor_out.shape(), &tensor_out_dup));
|
||||||
|
Tensor tensor_out_arg_max;
|
||||||
|
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(),
|
||||||
|
tensor_out.shape(),
|
||||||
|
&tensor_out_arg_max));
|
||||||
|
// Find arg_max for each tensor_out
|
||||||
|
ConstEigenMatrixMap tensor_in_mat(
|
||||||
|
tensor_in.flat<T>().data(), input_size[3],
|
||||||
|
input_size[2] * input_size[1] * input_size[0]);
|
||||||
|
EigenMatrixMap tensor_out_dup_mat(
|
||||||
|
tensor_out_dup.flat<T>().data(), output_size[3],
|
||||||
|
output_size[2] * output_size[1] * output_size[0]);
|
||||||
|
EigenIndexMatrixMap tensor_out_arg_max_mat(
|
||||||
|
tensor_out_arg_max.flat<int64>().data(), output_size[3],
|
||||||
|
output_size[2] * output_size[1] * output_size[0]);
|
||||||
|
|
||||||
|
tensor_out_arg_max.flat<int64>().setConstant(kInvalidMaxPoolingIndex);
|
||||||
|
// Initializes the duplicate output tensor with MIN<T>.
|
||||||
|
tensor_out_dup.flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
|
||||||
|
|
||||||
|
auto height_seq_tensor_flat = height_seq_tensor.flat<int64>();
|
||||||
|
auto width_seq_tensor_flat = width_seq_tensor.flat<int64>();
|
||||||
|
|
||||||
|
// Now walk through the process of fractional max pooling again.
|
||||||
|
// For both input and output,
|
||||||
|
// 0: batch
|
||||||
|
// 1: height / row
|
||||||
|
// 2: width / col
|
||||||
|
// 3: depth / channel
|
||||||
|
const int64 height_max = input_size[1] - 1;
|
||||||
|
const int64 width_max = input_size[2] - 1;
|
||||||
|
for (int64 b = 0; b < input_size[0]; ++b) {
|
||||||
|
// height sequence.
|
||||||
|
for (int64 hs = 0; hs < height_seq_tensor.dim_size(0) - 1; ++hs) {
|
||||||
|
// height start and end.
|
||||||
|
const int64 height_start = height_seq_tensor_flat(hs);
|
||||||
|
int64 height_end = overlapping_ ? height_seq_tensor_flat(hs + 1)
|
||||||
|
: height_seq_tensor_flat(hs + 1) - 1;
|
||||||
|
height_end = std::min(height_end, height_max);
|
||||||
|
|
||||||
|
// width sequence.
|
||||||
|
for (int64 ws = 0; ws < width_seq_tensor.dim_size(0) - 1; ++ws) {
|
||||||
|
const int64 out_index =
|
||||||
|
(b * output_size[1] + hs) * output_size[2] + ws;
|
||||||
|
// width start and end.
|
||||||
|
const int64 width_start = width_seq_tensor_flat(ws);
|
||||||
|
int64 width_end = overlapping_ ? width_seq_tensor_flat(ws + 1)
|
||||||
|
: width_seq_tensor_flat(ws + 1) - 1;
|
||||||
|
width_end = std::min(width_end, width_max);
|
||||||
|
for (int64 h = height_start; h <= height_end; ++h) {
|
||||||
|
for (int64 w = width_start; w <= width_end; ++w) {
|
||||||
|
const int64 in_index =
|
||||||
|
(b * input_size[1] + h) * input_size[2] + w;
|
||||||
|
// Walk through each channel (depth).
|
||||||
|
for (int64 d = 0; d < input_size[3]; ++d) {
|
||||||
|
const T& input_ref = tensor_in_mat.coeffRef(d, in_index);
|
||||||
|
T& output_ref = tensor_out_dup_mat.coeffRef(d, out_index);
|
||||||
|
int64& out_arg_max_ref =
|
||||||
|
tensor_out_arg_max_mat.coeffRef(d, out_index);
|
||||||
|
if (output_ref < input_ref ||
|
||||||
|
out_arg_max_ref == kInvalidMaxPoolingIndex) {
|
||||||
|
output_ref = input_ref;
|
||||||
|
int input_offset = in_index * input_size[3] + d;
|
||||||
|
out_arg_max_ref = input_offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check tensor_out_dup is the same as tensor_out.
|
||||||
|
ConstEigenMatrixMap tensor_out_mat(
|
||||||
|
tensor_out.flat<T>().data(), output_size[3],
|
||||||
|
output_size[2] * output_size[1] * output_size[0]);
|
||||||
|
const int64 num_reshaped_cols =
|
||||||
|
output_size[2] * output_size[1] * output_size[0];
|
||||||
|
for (int64 i = 0; i < num_reshaped_cols; ++i) {
|
||||||
|
for (int64 j = 0; j < output_size[3]; ++j) {
|
||||||
|
DCHECK_EQ(tensor_out_dup_mat(j, i), tensor_out_mat(j, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor* output = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(0, tensor_in.shape(), &output));
|
||||||
|
output->flat<T>().setZero();
|
||||||
|
|
||||||
|
auto out_backprop_flat = out_backprop.flat<T>();
|
||||||
|
auto input_backprop_flat = output->flat<T>();
|
||||||
|
auto out_arg_max_flat = tensor_out_arg_max.flat<int64>();
|
||||||
|
int num_total_outputs = out_backprop_flat.size();
|
||||||
|
int num_total_inputs = input_backprop_flat.size();
|
||||||
|
|
||||||
|
for (int index = 0; index < num_total_outputs; ++index) {
|
||||||
|
int input_backprop_index = out_arg_max_flat(index);
|
||||||
|
// According to maxpooling_op.cc, the performance impact below is small.
|
||||||
|
CHECK(input_backprop_index >= 0 &&
|
||||||
|
input_backprop_index < num_total_inputs)
|
||||||
|
<< "Invalid input backprop index: " << input_backprop_index << ", "
|
||||||
|
<< num_total_inputs;
|
||||||
|
input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool overlapping_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_FRACTIONALMAXPOOLGRAD(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("FractionalMaxPoolGrad") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<type>("T"), \
|
||||||
|
FractionalMaxPoolGradOp<type>)
|
||||||
|
|
||||||
|
REGISTER_FRACTIONALMAXPOOLGRAD(int32);
|
||||||
|
REGISTER_FRACTIONALMAXPOOLGRAD(int64);
|
||||||
|
REGISTER_FRACTIONALMAXPOOLGRAD(float);
|
||||||
|
REGISTER_FRACTIONALMAXPOOLGRAD(double);
|
||||||
|
|
||||||
|
#undef REGISTER_FRACTIONALMAXPOOLGRAD
|
||||||
|
} // namespace tensorflow
|
134
tensorflow/core/kernels/fractional_pool_common.cc
Normal file
134
tensorflow/core/kernels/fractional_pool_common.cc
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/kernels/fractional_pool_common.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
static std::vector<int64> GeneratePoolingSequencePseudoRandom(
|
||||||
|
int input_length, int output_length, GuardedPhiloxRandom* generator) {
|
||||||
|
std::vector<int64> cum_seq(output_length + 1, 0);
|
||||||
|
std::vector<int64> diff(output_length, 0);
|
||||||
|
|
||||||
|
double alpha = static_cast<double>(input_length) / output_length;
|
||||||
|
int k = input_length / output_length;
|
||||||
|
|
||||||
|
// In the paper [1], author proposes the following procedure to generate a
|
||||||
|
// pseudo random pooling region:
|
||||||
|
// 1) Set a_0 = 1, a_Nout = Nin;
|
||||||
|
// 2) a_i = ceil(alpha*(u+i))
|
||||||
|
// in which, i = 1, 2, ... , Nout-1
|
||||||
|
// u is a random number in (0,1) for all i
|
||||||
|
// alpha = Nin/Nout in (1,2)
|
||||||
|
// The sequence {a_i} should satisfy a_i-a_{i-1} = 1 or 2
|
||||||
|
// Note: for step 1), it makes more sense to make a_Nout = Nin+1, that way,
|
||||||
|
// a_i-a_{i-1} = 1 or 2 is also true for i = Nout.
|
||||||
|
//
|
||||||
|
// However, there are choices of alpha and u that will make
|
||||||
|
// a_i - a_{i-1} > 2. This happens at the left boundary. For example, with
|
||||||
|
// alpha = 1.732, u = 0.8, then a_1 = 4, a_1-a_0 = 3.
|
||||||
|
// This is why u_max1 is needed, i.e. u is a random number in (0,u_max1)
|
||||||
|
// instead of (0,1).
|
||||||
|
// Define k = ceil(alpha)-1, then we require:
|
||||||
|
// a_1 = alpha*(u+1) <= a_0+(k+1)
|
||||||
|
// ===> This gives u_max1 = (k+2)/alpha - 1.
|
||||||
|
//
|
||||||
|
// In addition, when extending the pooling sequence generation process for
|
||||||
|
// alpha beyond (1,2), e.g. (k,k+1); a check on the right boundary is also
|
||||||
|
// needed to make sure the last gap a_Nout-a_{Nout-1} >= k. Solving it gives
|
||||||
|
// u_max2 = (Nin+1-k)/alpha - (Nout-1)
|
||||||
|
// Here is an example where u > u_max2, alpha = 2.3, u = 0.7, u_max2 = 0.565,
|
||||||
|
// Nin = 23, Nout = 10; the sequence
|
||||||
|
// from a_0 to a_10 is:
|
||||||
|
// [1, 4, 7, 9, 11, 14, 16, 18, 21, 23, 24]
|
||||||
|
// The last gap is only 1.
|
||||||
|
//
|
||||||
|
// [1]: https://arxiv.org/abs/1412.6071
|
||||||
|
double u_max1 = (k + 2) / alpha - 1;
|
||||||
|
double u_max2 = (input_length + 1 - k) / alpha - (output_length - 1);
|
||||||
|
double max_u = std::min(u_max1, u_max2);
|
||||||
|
|
||||||
|
// Generate random number in parallel.
|
||||||
|
auto local_gen = generator->ReserveSamples32(2);
|
||||||
|
random::SimplePhilox random(&local_gen);
|
||||||
|
const double u = random.RandDouble() * max_u;
|
||||||
|
|
||||||
|
cum_seq[0] = 1;
|
||||||
|
cum_seq[output_length] = input_length + 1;
|
||||||
|
for (int i = 1; i < output_length; ++i) {
|
||||||
|
cum_seq[i] = static_cast<int>(ceil(alpha * (i + u)));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < output_length; ++i) {
|
||||||
|
diff[i] = cum_seq[i + 1] - cum_seq[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return diff;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<int64> GeneratePoolingSequenceRandom(
|
||||||
|
int input_length, int output_length, GuardedPhiloxRandom* generator) {
|
||||||
|
int k = input_length / output_length;
|
||||||
|
int num_random_spot = input_length % output_length;
|
||||||
|
std::vector<int64> diff(output_length, k);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_random_spot; ++i) {
|
||||||
|
diff[i] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Randomly shuffle this vector.
|
||||||
|
auto local_gen = generator->ReserveSamples32(diff.size());
|
||||||
|
random::SingleSampleAdapter<random::PhiloxRandom> single(&local_gen);
|
||||||
|
const auto uniform = [&single](uint32 n) { return single() % n; };
|
||||||
|
RandomShuffle(diff.begin(), diff.end(), uniform);
|
||||||
|
|
||||||
|
return diff;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
|
||||||
|
GuardedPhiloxRandom* generator,
|
||||||
|
bool pseudo_random) {
|
||||||
|
std::vector<int64> diff;
|
||||||
|
// This is a case that regular pooling can handle, just return diff with
|
||||||
|
// each element input_length/output_length.
|
||||||
|
if (input_length % output_length == 0) {
|
||||||
|
diff = std::vector<int64>(output_length, input_length / output_length);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pseudo_random) {
|
||||||
|
diff = GeneratePoolingSequencePseudoRandom(input_length, output_length,
|
||||||
|
generator);
|
||||||
|
} else {
|
||||||
|
diff =
|
||||||
|
GeneratePoolingSequenceRandom(input_length, output_length, generator);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanity check.
|
||||||
|
int k = input_length / output_length;
|
||||||
|
for (int i = 0; i < output_length; ++i) {
|
||||||
|
// k<= diff[i] <= k+1.
|
||||||
|
DCHECK_GE(diff[i], k);
|
||||||
|
DCHECK_LE(diff[i], k + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return cumulative sequence.
|
||||||
|
std::vector<int64> cum_seq(output_length + 1, 0);
|
||||||
|
for (int i = 1; i < cum_seq.size(); ++i) {
|
||||||
|
cum_seq[i] = cum_seq[i - 1] + diff[i - 1];
|
||||||
|
}
|
||||||
|
return cum_seq;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
78
tensorflow/core/kernels/fractional_pool_common.h
Normal file
78
tensorflow/core/kernels/fractional_pool_common.h
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
|
||||||
|
#define TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/util/guarded_philox_random.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Shuffle a container randomly, copied from random_shuffle_op.cc
|
||||||
|
template <class Iter, class Random>
|
||||||
|
static inline void RandomShuffle(Iter first, Iter last, const Random& uniform) {
|
||||||
|
if (first == last) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const auto stop = last - 1;
|
||||||
|
for (auto i = first; i != stop; ++i) {
|
||||||
|
using std::iter_swap;
|
||||||
|
iter_swap(i, i + uniform(last - i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate pooling sequence for fractional pooling along one dimension.
|
||||||
|
//
|
||||||
|
// Regular max/avg pooling can be viewed as a special case, in which given the
|
||||||
|
// * input_length: e.g. 10
|
||||||
|
// * output_length: e.g. 5
|
||||||
|
// it will generate pooling sequence as
|
||||||
|
// diff sequence: [2, 2, 2, 2, 2]
|
||||||
|
// or as
|
||||||
|
// cumulative sequence: [0, 2, 4, 6, 8, 10]
|
||||||
|
//
|
||||||
|
// In the case of fractional pooling, input_length is not an integer multiple of
|
||||||
|
// output_length, randomness plays a role when generating pooling sequence.
|
||||||
|
// There are two type of randomness (random vs pseudo-random) defined in paper:
|
||||||
|
// http://arxiv.org/abs/1412.6071
|
||||||
|
// You can check the paper for the difference between these two types.
|
||||||
|
//
|
||||||
|
// In summary, the generated diff sequence satisfy the following properties for
|
||||||
|
// both types of randomness:
|
||||||
|
// * length(generated_diff_pooling_sequence) = output_length
|
||||||
|
// * sum(generated_diff_pooling_sequence) = input_length
|
||||||
|
// * Let's define floor(input_length / output_length) = K, then
|
||||||
|
// K <= generated_diff_pooling_sequence[i] <= K+1
|
||||||
|
// For example, when input_length = 10, output_length = 6, the followings are
|
||||||
|
// valid pooling sequence:
|
||||||
|
// * [1, 2, 2, 1, 2, 2]
|
||||||
|
// * [1, 1, 2, 2, 2, 2]
|
||||||
|
// [1, 3, 2, 2, 2, 2] is not valid.
|
||||||
|
//
|
||||||
|
// Args:
|
||||||
|
// input_length: See above explanation
|
||||||
|
// output_length: See above explanation
|
||||||
|
// generator: Parallel version of random number generator
|
||||||
|
// pseudo_random: Whether or not use pseudo-random
|
||||||
|
// Returns:
|
||||||
|
// pooling_sequence: This is the cumulative pooling sequence.
|
||||||
|
std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
|
||||||
|
GuardedPhiloxRandom* generator,
|
||||||
|
bool pseudo_random);
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
|
@ -1515,4 +1515,205 @@ values: The `k` largest elements along each last dimensional slice.
|
|||||||
indices: The indices of `values` within the last dimension of `input`.
|
indices: The indices of `values` within the last dimension of `input`.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
REGISTER_OP("FractionalMaxPool")
|
||||||
|
.Input("value: T")
|
||||||
|
.Output("output: T")
|
||||||
|
.Output("row_pooling_sequence: int64")
|
||||||
|
.Output("col_pooling_sequence: int64")
|
||||||
|
.Attr("pooling_ratio: list(float) >=4")
|
||||||
|
.Attr("pseudo_random: bool = false")
|
||||||
|
.Attr("overlapping: bool = false")
|
||||||
|
.Attr("deterministic: bool = false")
|
||||||
|
.Attr("seed: int = 0")
|
||||||
|
.Attr("seed2: int = 0")
|
||||||
|
.Attr("T: {float, double, int32, int64}")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Performs fractional max pooling on the input.
|
||||||
|
|
||||||
|
Fractional max pooling is slightly different than regular max pooling. In
|
||||||
|
regular max pooling, you downsize an input set by taking the maximum value of
|
||||||
|
smaller N x N subsections of the set (often 2x2), and try to reduce the set by
|
||||||
|
a factor of N, where N is an integer. Fractional max pooling, as you might
|
||||||
|
expect from the word "fractional", means that the overall reduction ratio N
|
||||||
|
does not have to be an integer.
|
||||||
|
|
||||||
|
The sizes of the pooling regions are generated randomly but are fairly uniform.
|
||||||
|
For example, let's look at the height dimension, and the constraints on the
|
||||||
|
list of rows that will be pool boundaries.
|
||||||
|
|
||||||
|
First we define the following:
|
||||||
|
|
||||||
|
1. input_row_length : the number of rows from the input set
|
||||||
|
2. output_row_length : which will be smaller than the input
|
||||||
|
3. alpha = input_row_length / output_row_length : our reduction ratio
|
||||||
|
4. K = floor(alpha)
|
||||||
|
5. row_pooling_sequence : this is the result list of pool boundary rows
|
||||||
|
|
||||||
|
Then, row_pooling_sequence should satisfy:
|
||||||
|
|
||||||
|
1. a[0] = 0 : the first value of the sequence is 0
|
||||||
|
2. a[end] = input_row_length : the last value of the sequence is the size
|
||||||
|
3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
|
||||||
|
4. length(row_pooling_sequence) = output_row_length+1
|
||||||
|
|
||||||
|
For more details on fractional max pooling, see this paper:
|
||||||
|
[Benjamin Graham, Fractional Max-Pooling]
|
||||||
|
(http://arxiv.org/abs/1412.6071)
|
||||||
|
|
||||||
|
value: 4-D with shape `[batch, height, width, channels]`.
|
||||||
|
pooling_ratio: Pooling ratio for each dimension of `value`, currently only
|
||||||
|
supports row and col dimension and should be >= 1.0. For example, a valid
|
||||||
|
pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
|
||||||
|
must be 1.0 because we don't allow pooling on batch and channels
|
||||||
|
dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
|
||||||
|
respectively.
|
||||||
|
pseudo_random: When set to True, generates the pooling sequence in a
|
||||||
|
pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
|
||||||
|
Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
|
||||||
|
difference between pseudorandom and random.
|
||||||
|
overlapping: When set to True, it means when pooling, the values at the boundary
|
||||||
|
of adjacent pooling cells are used by both cells. For example:
|
||||||
|
|
||||||
|
`index 0 1 2 3 4`
|
||||||
|
|
||||||
|
`value 20 5 16 3 7`
|
||||||
|
|
||||||
|
If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
|
||||||
|
The result would be [20, 16] for fractional max pooling.
|
||||||
|
deterministic: When set to True, a fixed pooling region will be used when
|
||||||
|
iterating over a FractionalMaxPool node in the computation graph. Mainly used
|
||||||
|
in unit test to make FractionalMaxPool deterministic.
|
||||||
|
seed: If either seed or seed2 are set to be non-zero, the random number
|
||||||
|
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||||
|
random seed.
|
||||||
|
seed2: An second seed to avoid seed collision.
|
||||||
|
output: output tensor after fractional max pooling.
|
||||||
|
row_pooling_sequence: row pooling sequence, needed to calculate gradient.
|
||||||
|
col_pooling_sequence: column pooling sequence, needed to calculate gradient.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("FractionalMaxPoolGrad")
|
||||||
|
.Input("orig_input: T")
|
||||||
|
.Input("orig_output: T")
|
||||||
|
.Input("out_backprop: T")
|
||||||
|
.Input("row_pooling_sequence: int64")
|
||||||
|
.Input("col_pooling_sequence: int64")
|
||||||
|
.Output("output: T")
|
||||||
|
.Attr("overlapping: bool = false")
|
||||||
|
.Attr("T: {float, double, int32, int64}")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Computes gradient of the FractionalMaxPool function.
|
||||||
|
|
||||||
|
orig_input: Original input for `fractional_max_pool`
|
||||||
|
orig_output: Original output for `fractional_max_pool`
|
||||||
|
out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients
|
||||||
|
w.r.t. the output of `fractional_max_pool`.
|
||||||
|
row_pooling_sequence: row pooling sequence, form pooling region with
|
||||||
|
col_pooling_sequence.
|
||||||
|
col_pooling_sequence: column pooling sequence, form pooling region with
|
||||||
|
row_pooling sequence.
|
||||||
|
overlapping: When set to True, it means when pooling, the values at the boundary
|
||||||
|
of adjacent pooling cells are used by both cells. For example:
|
||||||
|
|
||||||
|
`index 0 1 2 3 4`
|
||||||
|
|
||||||
|
`value 20 5 16 3 7`
|
||||||
|
|
||||||
|
If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
|
||||||
|
The result would be [20, 16] for fractional max pooling.
|
||||||
|
output: 4-D. Gradients w.r.t. the input of `fractional_max_pool`.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
REGISTER_OP("FractionalAvgPool")
|
||||||
|
.Input("value: T")
|
||||||
|
.Output("output: T")
|
||||||
|
.Output("row_pooling_sequence: int64")
|
||||||
|
.Output("col_pooling_sequence: int64")
|
||||||
|
.Attr("pooling_ratio: list(float) >=4")
|
||||||
|
.Attr("pseudo_random: bool = false")
|
||||||
|
.Attr("overlapping: bool = false")
|
||||||
|
.Attr("deterministic: bool = false")
|
||||||
|
.Attr("seed: int = 0")
|
||||||
|
.Attr("seed2: int = 0")
|
||||||
|
.Attr("T: {float, double, int32, int64}")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Performs fractional average pooling on the input.
|
||||||
|
|
||||||
|
Fractional average pooling is similar to Fractional max pooling in the pooling
|
||||||
|
region generation step. The only difference is that after pooling regions are
|
||||||
|
generated, a mean operation is performed instead of a max operation in each
|
||||||
|
pooling region.
|
||||||
|
|
||||||
|
value: 4-D with shape `[batch, height, width, channels]`.
|
||||||
|
pooling_ratio: Pooling ratio for each dimension of `value`, currently only
|
||||||
|
supports row and col dimension and should be >= 1.0. For example, a valid
|
||||||
|
pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
|
||||||
|
must be 1.0 because we don't allow pooling on batch and channels
|
||||||
|
dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
|
||||||
|
respectively.
|
||||||
|
pseudo_random: When set to True, generates the pooling sequence in a
|
||||||
|
pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
|
||||||
|
Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
|
||||||
|
difference between pseudorandom and random.
|
||||||
|
overlapping: When set to True, it means when pooling, the values at the boundary
|
||||||
|
of adjacent pooling cells are used by both cells. For example:
|
||||||
|
|
||||||
|
`index 0 1 2 3 4`
|
||||||
|
|
||||||
|
`value 20 5 16 3 7`
|
||||||
|
|
||||||
|
If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
|
||||||
|
The result would be [41/3, 26/3] for fractional avg pooling.
|
||||||
|
deterministic: When set to True, a fixed pooling region will be used when
|
||||||
|
iterating over a FractionalAvgPool node in the computation graph. Mainly used
|
||||||
|
in unit test to make FractionalAvgPool deterministic.
|
||||||
|
seed: If either seed or seed2 are set to be non-zero, the random number
|
||||||
|
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||||
|
random seed.
|
||||||
|
seed2: An second seed to avoid seed collision.
|
||||||
|
output: output tensor after fractional avg pooling.
|
||||||
|
row_pooling_sequence: row pooling sequence, needed to calculate gradient.
|
||||||
|
col_pooling_sequence: column pooling sequence, needed to calculate gradient.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("FractionalAvgPoolGrad")
|
||||||
|
.Input("orig_input_tensor_shape: int64")
|
||||||
|
.Input("out_backprop: T")
|
||||||
|
.Input("row_pooling_sequence: int64")
|
||||||
|
.Input("col_pooling_sequence: int64")
|
||||||
|
.Output("output: T")
|
||||||
|
.Attr("overlapping: bool = false")
|
||||||
|
.Attr("T: {float, double, int32, int64}")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Computes gradient of the FractionalAvgPool function.
|
||||||
|
|
||||||
|
Unlike FractionalMaxPoolGrad, we don't need to find arg_max for
|
||||||
|
FractionalAvgPoolGrad, we just need to evenly back-propagate each element of
|
||||||
|
out_backprop to those indices that form the same pooling cell. Therefore, we
|
||||||
|
just need to know the shape of original input tensor, instead of the whole
|
||||||
|
tensor.
|
||||||
|
|
||||||
|
orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool`
|
||||||
|
out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients
|
||||||
|
w.r.t. the output of `fractional_avg_pool`.
|
||||||
|
row_pooling_sequence: row pooling sequence, form pooling region with
|
||||||
|
col_pooling_sequence.
|
||||||
|
col_pooling_sequence: column pooling sequence, form pooling region with
|
||||||
|
row_pooling sequence.
|
||||||
|
overlapping: When set to True, it means when pooling, the values at the boundary
|
||||||
|
of adjacent pooling cells are used by both cells. For example:
|
||||||
|
|
||||||
|
`index 0 1 2 3 4`
|
||||||
|
|
||||||
|
`value 20 5 16 3 7`
|
||||||
|
|
||||||
|
If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
|
||||||
|
The result would be [41/3, 26/3] for fractional avg pooling.
|
||||||
|
output: 4-D. Gradients w.r.t. the input of `fractional_avg_pool`.
|
||||||
|
)doc");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -0,0 +1,82 @@
|
|||||||
|
### `tf.nn.fractional_max_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_max_pool}
|
||||||
|
|
||||||
|
Performs fractional max pooling on the input.
|
||||||
|
|
||||||
|
Fractional max pooling is slightly different than regular max pooling. In
|
||||||
|
regular max pooling, you downsize an input set by taking the maximum value of
|
||||||
|
smaller N x N subsections of the set (often 2x2), and try to reduce the set by
|
||||||
|
a factor of N, where N is an integer. Fractional max pooling, as you might
|
||||||
|
expect from the word "fractional", means that the overall reduction ratio N
|
||||||
|
does not have to be an integer.
|
||||||
|
|
||||||
|
The sizes of the pooling regions are generated randomly but are fairly uniform.
|
||||||
|
For example, let's look at the height dimension, and the constraints on the
|
||||||
|
list of rows that will be pool boundaries.
|
||||||
|
|
||||||
|
First we define the following:
|
||||||
|
|
||||||
|
1. input_row_length : the number of rows from the input set
|
||||||
|
2. output_row_length : which will be smaller than the input
|
||||||
|
3. alpha = input_row_length / output_row_length : our reduction ratio
|
||||||
|
4. K = floor(alpha)
|
||||||
|
5. row_pooling_sequence : this is the result list of pool boundary rows
|
||||||
|
|
||||||
|
Then, row_pooling_sequence should satisfy:
|
||||||
|
|
||||||
|
1. a[0] = 0 : the first value of the sequence is 0
|
||||||
|
2. a[end] = input_row_length : the last value of the sequence is the size
|
||||||
|
3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
|
||||||
|
4. length(row_pooling_sequence) = output_row_length+1
|
||||||
|
|
||||||
|
For more details on fractional max pooling, see this paper:
|
||||||
|
[Benjamin Graham, Fractional Max-Pooling]
|
||||||
|
(http://arxiv.org/abs/1412.6071)
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
|
||||||
|
4-D with shape `[batch, height, width, channels]`.
|
||||||
|
* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
|
||||||
|
Pooling ratio for each dimension of `value`, currently only
|
||||||
|
supports row and col dimension and should be >= 1.0. For example, a valid
|
||||||
|
pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
|
||||||
|
must be 1.0 because we don't allow pooling on batch and channels
|
||||||
|
dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
|
||||||
|
respectively.
|
||||||
|
* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, generates the pooling sequence in a
|
||||||
|
pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
|
||||||
|
Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
|
||||||
|
difference between pseudorandom and random.
|
||||||
|
* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, it means when pooling, the values at the boundary
|
||||||
|
of adjacent pooling cells are used by both cells. For example:
|
||||||
|
|
||||||
|
`index 0 1 2 3 4`
|
||||||
|
|
||||||
|
`value 20 5 16 3 7`
|
||||||
|
|
||||||
|
If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
|
||||||
|
The result would be [20, 16] for fractional max pooling.
|
||||||
|
|
||||||
|
* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, a fixed pooling region will be used when
|
||||||
|
iterating over a FractionalMaxPool node in the computation graph. Mainly used
|
||||||
|
in unit test to make FractionalMaxPool deterministic.
|
||||||
|
* <b>`seed`</b>: An optional `int`. Defaults to `0`.
|
||||||
|
If either seed or seed2 are set to be non-zero, the random number
|
||||||
|
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||||
|
random seed.
|
||||||
|
* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
|
||||||
|
An second seed to avoid seed collision.
|
||||||
|
* <b>`name`</b>: A name for the operation (optional).
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
|
||||||
|
|
||||||
|
* <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional max pooling.
|
||||||
|
* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
|
||||||
|
* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
|
||||||
|
|
@ -0,0 +1,57 @@
|
|||||||
|
### `tf.nn.fractional_avg_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_avg_pool}
|
||||||
|
|
||||||
|
Performs fractional average pooling on the input.
|
||||||
|
|
||||||
|
Fractional average pooling is similar to Fractional max pooling in the pooling
|
||||||
|
region generation step. The only difference is that after pooling regions are
|
||||||
|
generated, a mean operation is performed instead of a max operation in each
|
||||||
|
pooling region.
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
|
||||||
|
4-D with shape `[batch, height, width, channels]`.
|
||||||
|
* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
|
||||||
|
Pooling ratio for each dimension of `value`, currently only
|
||||||
|
supports row and col dimension and should be >= 1.0. For example, a valid
|
||||||
|
pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
|
||||||
|
must be 1.0 because we don't allow pooling on batch and channels
|
||||||
|
dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
|
||||||
|
respectively.
|
||||||
|
* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, generates the pooling sequence in a
|
||||||
|
pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
|
||||||
|
Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
|
||||||
|
difference between pseudorandom and random.
|
||||||
|
* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, it means when pooling, the values at the boundary
|
||||||
|
of adjacent pooling cells are used by both cells. For example:
|
||||||
|
|
||||||
|
`index 0 1 2 3 4`
|
||||||
|
|
||||||
|
`value 20 5 16 3 7`
|
||||||
|
|
||||||
|
If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
|
||||||
|
The result would be [41/3, 26/3] for fractional avg pooling.
|
||||||
|
|
||||||
|
* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, a fixed pooling region will be used when
|
||||||
|
iterating over a FractionalAvgPool node in the computation graph. Mainly used
|
||||||
|
in unit test to make FractionalAvgPool deterministic.
|
||||||
|
* <b>`seed`</b>: An optional `int`. Defaults to `0`.
|
||||||
|
If either seed or seed2 are set to be non-zero, the random number
|
||||||
|
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||||
|
random seed.
|
||||||
|
* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
|
||||||
|
An second seed to avoid seed collision.
|
||||||
|
* <b>`name`</b>: A name for the operation (optional).
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
|
||||||
|
|
||||||
|
* <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional avg pooling.
|
||||||
|
* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
|
||||||
|
* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
|
||||||
|
|
@ -473,6 +473,8 @@
|
|||||||
* [`embedding_lookup_sparse`](../../api_docs/python/nn.md#embedding_lookup_sparse)
|
* [`embedding_lookup_sparse`](../../api_docs/python/nn.md#embedding_lookup_sparse)
|
||||||
* [`erosion2d`](../../api_docs/python/nn.md#erosion2d)
|
* [`erosion2d`](../../api_docs/python/nn.md#erosion2d)
|
||||||
* [`fixed_unigram_candidate_sampler`](../../api_docs/python/nn.md#fixed_unigram_candidate_sampler)
|
* [`fixed_unigram_candidate_sampler`](../../api_docs/python/nn.md#fixed_unigram_candidate_sampler)
|
||||||
|
* [`fractional_avg_pool`](../../api_docs/python/nn.md#fractional_avg_pool)
|
||||||
|
* [`fractional_max_pool`](../../api_docs/python/nn.md#fractional_max_pool)
|
||||||
* [`in_top_k`](../../api_docs/python/nn.md#in_top_k)
|
* [`in_top_k`](../../api_docs/python/nn.md#in_top_k)
|
||||||
* [`l2_loss`](../../api_docs/python/nn.md#l2_loss)
|
* [`l2_loss`](../../api_docs/python/nn.md#l2_loss)
|
||||||
* [`l2_normalize`](../../api_docs/python/nn.md#l2_normalize)
|
* [`l2_normalize`](../../api_docs/python/nn.md#l2_normalize)
|
||||||
|
@ -828,6 +828,151 @@ Performs 3D max pooling on the input.
|
|||||||
A `Tensor`. Has the same type as `input`. The max pooled output tensor.
|
A `Tensor`. Has the same type as `input`. The max pooled output tensor.
|
||||||
|
|
||||||
|
|
||||||
|
- - -
|
||||||
|
|
||||||
|
### `tf.nn.fractional_avg_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_avg_pool}
|
||||||
|
|
||||||
|
Performs fractional average pooling on the input.
|
||||||
|
|
||||||
|
Fractional average pooling is similar to Fractional max pooling in the pooling
|
||||||
|
region generation step. The only difference is that after pooling regions are
|
||||||
|
generated, a mean operation is performed instead of a max operation in each
|
||||||
|
pooling region.
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
|
||||||
|
4-D with shape `[batch, height, width, channels]`.
|
||||||
|
* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
|
||||||
|
Pooling ratio for each dimension of `value`, currently only
|
||||||
|
supports row and col dimension and should be >= 1.0. For example, a valid
|
||||||
|
pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
|
||||||
|
must be 1.0 because we don't allow pooling on batch and channels
|
||||||
|
dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
|
||||||
|
respectively.
|
||||||
|
* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, generates the pooling sequence in a
|
||||||
|
pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
|
||||||
|
Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
|
||||||
|
difference between pseudorandom and random.
|
||||||
|
* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, it means when pooling, the values at the boundary
|
||||||
|
of adjacent pooling cells are used by both cells. For example:
|
||||||
|
|
||||||
|
`index 0 1 2 3 4`
|
||||||
|
|
||||||
|
`value 20 5 16 3 7`
|
||||||
|
|
||||||
|
If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
|
||||||
|
The result would be [41/3, 26/3] for fractional avg pooling.
|
||||||
|
|
||||||
|
* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, a fixed pooling region will be used when
|
||||||
|
iterating over a FractionalAvgPool node in the computation graph. Mainly used
|
||||||
|
in unit test to make FractionalAvgPool deterministic.
|
||||||
|
* <b>`seed`</b>: An optional `int`. Defaults to `0`.
|
||||||
|
If either seed or seed2 are set to be non-zero, the random number
|
||||||
|
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||||
|
random seed.
|
||||||
|
* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
|
||||||
|
An second seed to avoid seed collision.
|
||||||
|
* <b>`name`</b>: A name for the operation (optional).
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
|
||||||
|
|
||||||
|
* <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional avg pooling.
|
||||||
|
* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
|
||||||
|
* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
|
||||||
|
|
||||||
|
|
||||||
|
- - -
|
||||||
|
|
||||||
|
### `tf.nn.fractional_max_pool(value, pooling_ratio, pseudo_random=None, overlapping=None, deterministic=None, seed=None, seed2=None, name=None)` {#fractional_max_pool}
|
||||||
|
|
||||||
|
Performs fractional max pooling on the input.
|
||||||
|
|
||||||
|
Fractional max pooling is slightly different than regular max pooling. In
|
||||||
|
regular max pooling, you downsize an input set by taking the maximum value of
|
||||||
|
smaller N x N subsections of the set (often 2x2), and try to reduce the set by
|
||||||
|
a factor of N, where N is an integer. Fractional max pooling, as you might
|
||||||
|
expect from the word "fractional", means that the overall reduction ratio N
|
||||||
|
does not have to be an integer.
|
||||||
|
|
||||||
|
The sizes of the pooling regions are generated randomly but are fairly uniform.
|
||||||
|
For example, let's look at the height dimension, and the constraints on the
|
||||||
|
list of rows that will be pool boundaries.
|
||||||
|
|
||||||
|
First we define the following:
|
||||||
|
|
||||||
|
1. input_row_length : the number of rows from the input set
|
||||||
|
2. output_row_length : which will be smaller than the input
|
||||||
|
3. alpha = input_row_length / output_row_length : our reduction ratio
|
||||||
|
4. K = floor(alpha)
|
||||||
|
5. row_pooling_sequence : this is the result list of pool boundary rows
|
||||||
|
|
||||||
|
Then, row_pooling_sequence should satisfy:
|
||||||
|
|
||||||
|
1. a[0] = 0 : the first value of the sequence is 0
|
||||||
|
2. a[end] = input_row_length : the last value of the sequence is the size
|
||||||
|
3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
|
||||||
|
4. length(row_pooling_sequence) = output_row_length+1
|
||||||
|
|
||||||
|
For more details on fractional max pooling, see this paper:
|
||||||
|
[Benjamin Graham, Fractional Max-Pooling]
|
||||||
|
(http://arxiv.org/abs/1412.6071)
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`value`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
|
||||||
|
4-D with shape `[batch, height, width, channels]`.
|
||||||
|
* <b>`pooling_ratio`</b>: A list of `floats` that has length `>= 4`.
|
||||||
|
Pooling ratio for each dimension of `value`, currently only
|
||||||
|
supports row and col dimension and should be >= 1.0. For example, a valid
|
||||||
|
pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements
|
||||||
|
must be 1.0 because we don't allow pooling on batch and channels
|
||||||
|
dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions
|
||||||
|
respectively.
|
||||||
|
* <b>`pseudo_random`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, generates the pooling sequence in a
|
||||||
|
pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin
|
||||||
|
Graham, Fractional Max-Pooling] (http://arxiv.org/abs/1412.6071) for
|
||||||
|
difference between pseudorandom and random.
|
||||||
|
* <b>`overlapping`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, it means when pooling, the values at the boundary
|
||||||
|
of adjacent pooling cells are used by both cells. For example:
|
||||||
|
|
||||||
|
`index 0 1 2 3 4`
|
||||||
|
|
||||||
|
`value 20 5 16 3 7`
|
||||||
|
|
||||||
|
If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice.
|
||||||
|
The result would be [20, 16] for fractional max pooling.
|
||||||
|
|
||||||
|
* <b>`deterministic`</b>: An optional `bool`. Defaults to `False`.
|
||||||
|
When set to True, a fixed pooling region will be used when
|
||||||
|
iterating over a FractionalMaxPool node in the computation graph. Mainly used
|
||||||
|
in unit test to make FractionalMaxPool deterministic.
|
||||||
|
* <b>`seed`</b>: An optional `int`. Defaults to `0`.
|
||||||
|
If either seed or seed2 are set to be non-zero, the random number
|
||||||
|
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||||
|
random seed.
|
||||||
|
* <b>`seed2`</b>: An optional `int`. Defaults to `0`.
|
||||||
|
An second seed to avoid seed collision.
|
||||||
|
* <b>`name`</b>: A name for the operation (optional).
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
A tuple of `Tensor` objects (output, row_pooling_sequence, col_pooling_sequence).
|
||||||
|
|
||||||
|
* <b>`output`</b>: A `Tensor`. Has the same type as `value`. output tensor after fractional max pooling.
|
||||||
|
* <b>`row_pooling_sequence`</b>: A `Tensor` of type `int64`. row pooling sequence, needed to calculate gradient.
|
||||||
|
* <b>`col_pooling_sequence`</b>: A `Tensor` of type `int64`. column pooling sequence, needed to calculate gradient.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Morphological filtering
|
## Morphological filtering
|
||||||
|
|
||||||
|
@ -36,6 +36,8 @@ py_tests(
|
|||||||
"determinant_op_test.py",
|
"determinant_op_test.py",
|
||||||
"edit_distance_op_test.py",
|
"edit_distance_op_test.py",
|
||||||
"fifo_queue_test.py",
|
"fifo_queue_test.py",
|
||||||
|
"fractional_avg_pool_op_test.py",
|
||||||
|
"fractional_max_pool_op_test.py",
|
||||||
"identity_op_py_test.py",
|
"identity_op_py_test.py",
|
||||||
"in_topk_op_test.py",
|
"in_topk_op_test.py",
|
||||||
"io_ops_test.py",
|
"io_ops_test.py",
|
||||||
|
521
tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
Normal file
521
tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
Normal file
@ -0,0 +1,521 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for fractional average pool operation."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.python.ops import gen_nn_ops
|
||||||
|
|
||||||
|
|
||||||
|
class FractionalAvgTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
# Random number generate with seed.
|
||||||
|
_PRNG = np.random.RandomState(341261000)
|
||||||
|
_SEED = 341261001
|
||||||
|
_SEED2 = 341261002
|
||||||
|
|
||||||
|
def _AvgPoolAlongRows(self, input_matrix, row_seq, overlapping):
|
||||||
|
"""Perform average pool along row of a 2-D matrix based on row_seq.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_matrix: A 2-D matrix.
|
||||||
|
row_seq: Cumulative pooling sequence along row.
|
||||||
|
overlapping: Whether or not use overlapping when pooling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 2-D matrix, with
|
||||||
|
* num_rows = len(row_seq)-1
|
||||||
|
* num_cols = input_matrix.num_cols.
|
||||||
|
"""
|
||||||
|
output_image = np.zeros(input_matrix.shape[1])
|
||||||
|
row_max = row_seq[-1]
|
||||||
|
for i in range(row_seq.shape[0] - 1):
|
||||||
|
row_start = row_seq[i]
|
||||||
|
row_end = row_seq[i + 1] + 1 if overlapping else row_seq[i + 1]
|
||||||
|
row_end = min(row_end, row_max)
|
||||||
|
output_image = np.vstack((output_image,
|
||||||
|
np.mean(input_matrix[row_start:row_end, :],
|
||||||
|
axis=0))) # axis 0 is along row
|
||||||
|
# remove the sentinel row
|
||||||
|
return output_image[1:, :]
|
||||||
|
|
||||||
|
def _AvgPoolAlongCols(self, input_matrix, col_seq, overlapping):
|
||||||
|
"""Perform average pool along column of a 2-D matrix based on col_seq.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_matrix: A 2-D matrix.
|
||||||
|
col_seq: Cumulative pooling sequence along column.
|
||||||
|
overlapping: Whether or not use overlapping when pooling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 2-D matrix, with
|
||||||
|
* num_rows = input_matrix.num_rows
|
||||||
|
* num_cols = len(col_seq)-1.
|
||||||
|
"""
|
||||||
|
input_matrix = input_matrix.transpose()
|
||||||
|
output_matrix = self._AvgPoolAlongRows(input_matrix, col_seq, overlapping)
|
||||||
|
return output_matrix.transpose()
|
||||||
|
|
||||||
|
def _GetExpectedFractionalAvgPoolResult(self, input_tensor, row_seq, col_seq,
|
||||||
|
overlapping):
|
||||||
|
"""Get expected fractional average pooling result.
|
||||||
|
|
||||||
|
row_seq and col_seq together defines the fractional pooling region.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: Original input tensor, assuming it is a 4-D tensor, with
|
||||||
|
dimension as [batch, height/row, width/column, channels/depth].
|
||||||
|
row_seq: Cumulative pooling sequence along row.
|
||||||
|
col_seq: Cumulative pooling sequence along column.
|
||||||
|
overlapping: Use overlapping when doing pooling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 4-D tensor that is the result of average pooling on input_tensor based
|
||||||
|
on pooling region defined by row_seq and col_seq, conditioned on whether
|
||||||
|
or not overlapping is used.
|
||||||
|
"""
|
||||||
|
input_shape = input_tensor.shape
|
||||||
|
output_shape = (input_shape[0], len(row_seq) - 1, len(col_seq) - 1,
|
||||||
|
input_shape[3])
|
||||||
|
output_tensor = np.zeros(shape=output_shape, dtype=input_tensor.dtype)
|
||||||
|
for batch in range(input_shape[0]):
|
||||||
|
for channel in range(input_shape[3]):
|
||||||
|
two_dim_slice = input_tensor[batch, :, :, channel]
|
||||||
|
tmp = self._AvgPoolAlongRows(two_dim_slice, row_seq, overlapping)
|
||||||
|
output_tensor[batch, :, :, channel] = self._AvgPoolAlongCols(
|
||||||
|
tmp, col_seq, overlapping)
|
||||||
|
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
def _ValidateFractionalAvgPoolResult(self, input_tensor, pooling_ratio,
|
||||||
|
pseudo_random, overlapping):
|
||||||
|
"""Validate FractionalAvgPool's result against expected.
|
||||||
|
|
||||||
|
Expected result is computed given input_tensor, and pooling region defined
|
||||||
|
by row_seq and col_seq.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: A tensor or numpy ndarray.
|
||||||
|
pooling_ratio: A list or tuple of length 4, first and last element be 1.
|
||||||
|
pseudo_random: Use pseudo random method to generate pooling sequence.
|
||||||
|
overlapping: Use overlapping when pooling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
with self.test_session() as sess:
|
||||||
|
p, r, c = tf.nn.fractional_avg_pool(input_tensor,
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random,
|
||||||
|
overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
actual, row_seq, col_seq = sess.run([p, r, c])
|
||||||
|
expected = self._GetExpectedFractionalAvgPoolResult(input_tensor, row_seq,
|
||||||
|
col_seq, overlapping)
|
||||||
|
self.assertShapeEqual(expected, p)
|
||||||
|
self.assertAllClose(expected, actual)
|
||||||
|
|
||||||
|
def _testVisually(self):
|
||||||
|
"""Manual test by printing out intermediate result of a small random tensor.
|
||||||
|
|
||||||
|
Since _GetExpectedFractionalAvgPoolResult is 'automated', it feels safer to
|
||||||
|
have a test case that you can see what's happening.
|
||||||
|
This test will generate a small, random, int 2D matrix, and feed it to
|
||||||
|
FractionalAvgPool and _GetExpectedFractionalAvgPoolResult.
|
||||||
|
"""
|
||||||
|
num_rows = 6
|
||||||
|
num_cols = 6
|
||||||
|
tensor_shape = (1, num_rows, num_cols, 1)
|
||||||
|
pseudo_random = False
|
||||||
|
for overlapping in True, False:
|
||||||
|
print("-" * 70)
|
||||||
|
print("Testing FractionalAvgPool with overlapping = {}".format(
|
||||||
|
overlapping))
|
||||||
|
rand_mat = self._PRNG.randint(10, size=tensor_shape)
|
||||||
|
pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
|
||||||
|
with self.test_session() as sess:
|
||||||
|
p, r, c = tf.nn.fractional_avg_pool(
|
||||||
|
rand_mat.astype(np.float32),
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random,
|
||||||
|
overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
tensor_output, row_seq, col_seq = sess.run([p, r, c])
|
||||||
|
expected_result = self._GetExpectedFractionalAvgPoolResult(
|
||||||
|
rand_mat.astype(np.float32), row_seq, col_seq, overlapping)
|
||||||
|
print("row sequence:")
|
||||||
|
print(row_seq)
|
||||||
|
print("column sequence:")
|
||||||
|
print(col_seq)
|
||||||
|
|
||||||
|
print("Input:")
|
||||||
|
# Print input with pooling region marked.
|
||||||
|
for i in range(num_rows):
|
||||||
|
row_to_print = []
|
||||||
|
for j in range(num_cols):
|
||||||
|
if j in col_seq:
|
||||||
|
row_to_print.append("|")
|
||||||
|
row_to_print.append(str(rand_mat[0, i, j, 0]))
|
||||||
|
row_to_print.append("|")
|
||||||
|
if i in row_seq:
|
||||||
|
print("-" * 2 * len(row_to_print))
|
||||||
|
print(" ".join(row_to_print))
|
||||||
|
print("-" * 2 * len(row_to_print))
|
||||||
|
|
||||||
|
print("Output from FractionalAvgPool:")
|
||||||
|
print(tensor_output[0, :, :, 0])
|
||||||
|
print("Expected result:")
|
||||||
|
print(expected_result[0, :, :, 0])
|
||||||
|
|
||||||
|
def testAllInputOptions(self):
|
||||||
|
"""Try all possible input options for fractional_avg_pool.
|
||||||
|
"""
|
||||||
|
num_batches = 5
|
||||||
|
num_channels = 3
|
||||||
|
num_rows = 20
|
||||||
|
num_cols = 30
|
||||||
|
for pseudo_random in True, False:
|
||||||
|
for overlapping in True, False:
|
||||||
|
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
# random tensor with value in [-500.0, 500.0)
|
||||||
|
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
|
||||||
|
self._ValidateFractionalAvgPoolResult(
|
||||||
|
rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
|
||||||
|
overlapping)
|
||||||
|
|
||||||
|
def testIntegerTensorInput(self):
|
||||||
|
"""Test FractionalAvgPool works fine when input tensor is integer type.
|
||||||
|
|
||||||
|
I would have used _ValidateFractionalAvgPoolResult function to automate this
|
||||||
|
process, however, there's rounding issue. It is caused by numpy.mean cast
|
||||||
|
integer input to numpy.float64 for intermediate use. While for
|
||||||
|
fractional_avg_pool, the mean operation is integer division (trucated). So,
|
||||||
|
for this test case, I will hard code a simple matrix.
|
||||||
|
"""
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
tensor_shape = (1, 6, 6, 1)
|
||||||
|
# pyformat: disable
|
||||||
|
mat = np.array([
|
||||||
|
[2, 6, 4, 1, 3, 6],
|
||||||
|
[8, 9, 1, 6, 6, 8],
|
||||||
|
[3, 9, 8, 2, 5, 6],
|
||||||
|
[2, 7, 9, 5, 4, 5],
|
||||||
|
[8, 5, 0, 5, 7, 4],
|
||||||
|
[4, 4, 5, 9, 7, 2]
|
||||||
|
])
|
||||||
|
# pyformat: enable
|
||||||
|
with self.test_session() as sess:
|
||||||
|
# Since deterministic = True, seed and seed2 are fixed. Therefore r, and c
|
||||||
|
# are the same each time. We can have an expected result precomputed.
|
||||||
|
# r = [0, 2, 4, 6]
|
||||||
|
# c = [0, 1, 3, 4, 6]
|
||||||
|
|
||||||
|
# pyformat: disable
|
||||||
|
expected = np.array([
|
||||||
|
[6, 5, 3, 5],
|
||||||
|
[5, 5, 4, 5],
|
||||||
|
[5, 4, 7, 5]
|
||||||
|
]).reshape((1, 3, 4, 1))
|
||||||
|
# pyformat: enable
|
||||||
|
p, unused_r, unused_c = tf.nn.fractional_avg_pool(
|
||||||
|
mat.reshape(tensor_shape), [1, math.sqrt(3), math.sqrt(2), 1],
|
||||||
|
pseudo_random,
|
||||||
|
overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
actual = sess.run(p)
|
||||||
|
self.assertShapeEqual(expected, p)
|
||||||
|
self.assertAllClose(expected, actual)
|
||||||
|
|
||||||
|
def testDifferentTensorShapes(self):
|
||||||
|
"""Test different shapes of input tensor.
|
||||||
|
|
||||||
|
Mainly test different combinations of num_rows and num_cols.
|
||||||
|
"""
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
for num_batches in [1, 3]:
|
||||||
|
for num_channels in [1, 3]:
|
||||||
|
for num_rows in [10, 20, 50]:
|
||||||
|
for num_cols in [10, 20, 50]:
|
||||||
|
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
# random tensor with value in [-500.0, 500.0)
|
||||||
|
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
|
||||||
|
self._ValidateFractionalAvgPoolResult(
|
||||||
|
rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
|
||||||
|
overlapping)
|
||||||
|
|
||||||
|
def testLargePoolingRatio(self):
|
||||||
|
"""Test when pooling ratio is not within [1, 2).
|
||||||
|
"""
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
num_batches = 3
|
||||||
|
num_channels = 3
|
||||||
|
num_rows = 30
|
||||||
|
num_cols = 50
|
||||||
|
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
for row_ratio in [math.sqrt(11), math.sqrt(37)]:
|
||||||
|
for col_ratio in [math.sqrt(11), math.sqrt(27)]:
|
||||||
|
# random tensor with value in [-500.0, 500.0)
|
||||||
|
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
|
||||||
|
self._ValidateFractionalAvgPoolResult(rand_mat,
|
||||||
|
[1, row_ratio, col_ratio, 1],
|
||||||
|
pseudo_random, overlapping)
|
||||||
|
|
||||||
|
def testDivisiblePoolingRatio(self):
|
||||||
|
"""Test when num of rows/cols can evenly divide pooling ratio.
|
||||||
|
|
||||||
|
This is a case regular average pooling can handle. Should be handled by
|
||||||
|
fractional pooling as well.
|
||||||
|
"""
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
num_batches = 3
|
||||||
|
num_channels = 3
|
||||||
|
num_rows = 30
|
||||||
|
num_cols = 50
|
||||||
|
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
# random tensor with value in [-500.0, 500.0)
|
||||||
|
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
|
||||||
|
self._ValidateFractionalAvgPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
|
||||||
|
overlapping)
|
||||||
|
|
||||||
|
|
||||||
|
class FractionalAvgPoolGradTest(tf.test.TestCase):
|
||||||
|
"""Tests for FractionalAvgPoolGrad.
|
||||||
|
|
||||||
|
Two types of tests for FractionalAvgPoolGrad.
|
||||||
|
1) Test fractional_avg_pool_grad() directly.
|
||||||
|
This type of test relies on gen_nn_ops._avg_pool_grad() returns the
|
||||||
|
correct result. For example:
|
||||||
|
* input_tensor_shape = (1, 10, 10, 1)
|
||||||
|
* window_size = (1, 2, 2, 1)
|
||||||
|
* stride_size = (1, 2, 2, 1)
|
||||||
|
* padding: not really important, since 10/2 is divisible
|
||||||
|
avg pooling should generate the same result as fractional avg pooling with:
|
||||||
|
* row_sequence = [0, 2, 4, 6, 8, 10]
|
||||||
|
* col_sequence = [0, 2, 4, 6, 8, 10]
|
||||||
|
* overlapping = False
|
||||||
|
This also means their gradients in such case will be the same.
|
||||||
|
|
||||||
|
Similarly, when
|
||||||
|
* input_tensor_shape = (1, 7, 7, 1)
|
||||||
|
* window_size = (1, 3, 3, 1)
|
||||||
|
* stride_size = (1, 2, 2, 1)
|
||||||
|
* padding: not important
|
||||||
|
avg pooling should generate the same result as fractional avg pooling with:
|
||||||
|
* row_sequence = [0, 2, 4, 7]
|
||||||
|
* col_sequence = [0, 2, 4, 7]
|
||||||
|
* overlapping = True
|
||||||
|
2) Test through compute_gradient_error()
|
||||||
|
"""
|
||||||
|
_PRNG = np.random.RandomState(341261004)
|
||||||
|
_SEED = 341261005
|
||||||
|
_SEED2 = 341261006
|
||||||
|
|
||||||
|
def _GenerateRandomInputTensor(self, shape):
|
||||||
|
num_elements = 1
|
||||||
|
for dim_size in shape:
|
||||||
|
num_elements *= dim_size
|
||||||
|
x = self._PRNG.rand(num_elements) * 1000
|
||||||
|
return x.reshape(shape)
|
||||||
|
|
||||||
|
def testDirectNotUseOverlapping(self):
|
||||||
|
for num_batches in [1, 3]:
|
||||||
|
for row_window_size in [2, 5]:
|
||||||
|
for col_window_size in [2, 4]:
|
||||||
|
num_rows = row_window_size * 5
|
||||||
|
num_cols = col_window_size * 7
|
||||||
|
for num_channels in [1, 2]:
|
||||||
|
input_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(self._GenerateRandomInputTensor(
|
||||||
|
input_shape).astype(np.float32))
|
||||||
|
window_size = [1, row_window_size, col_window_size, 1]
|
||||||
|
stride_size = [1, row_window_size, col_window_size, 1]
|
||||||
|
padding = "VALID"
|
||||||
|
output_tensor = tf.nn.avg_pool(input_tensor, window_size,
|
||||||
|
stride_size, padding)
|
||||||
|
output_data = output_tensor.eval()
|
||||||
|
num_elements = 1
|
||||||
|
for dim_size in output_data.shape:
|
||||||
|
num_elements *= dim_size
|
||||||
|
output_backprop = (self._PRNG.rand(num_elements) *
|
||||||
|
1000).reshape(output_data.shape)
|
||||||
|
input_backprop_tensor = gen_nn_ops._avg_pool_grad(
|
||||||
|
input_tensor.get_shape(), output_backprop, window_size,
|
||||||
|
stride_size, padding)
|
||||||
|
input_backprop = input_backprop_tensor.eval()
|
||||||
|
row_seq = list(range(0, num_rows + 1, row_window_size))
|
||||||
|
col_seq = list(range(0, num_cols + 1, col_window_size))
|
||||||
|
fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad(
|
||||||
|
input_tensor.get_shape(),
|
||||||
|
output_backprop,
|
||||||
|
row_seq,
|
||||||
|
col_seq,
|
||||||
|
overlapping=False)
|
||||||
|
fap_input_backprop = fap_input_backprop_tensor.eval()
|
||||||
|
self.assertShapeEqual(input_backprop, fap_input_backprop_tensor)
|
||||||
|
self.assertAllClose(input_backprop, fap_input_backprop)
|
||||||
|
|
||||||
|
def testDirectUseOverlapping(self):
|
||||||
|
for num_batches in [1, 3]:
|
||||||
|
for row_window_size in [2, 5]:
|
||||||
|
for col_window_size in [2, 4]:
|
||||||
|
num_rows = (row_window_size - 1) * 5 + 1
|
||||||
|
num_cols = (col_window_size - 1) * 7 + 1
|
||||||
|
for num_channels in [1, 2]:
|
||||||
|
input_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(self._GenerateRandomInputTensor(
|
||||||
|
input_shape).astype(np.float32))
|
||||||
|
window_size = [1, row_window_size, col_window_size, 1]
|
||||||
|
stride_size = [1, row_window_size - 1, col_window_size - 1, 1]
|
||||||
|
padding = "VALID"
|
||||||
|
output_tensor = tf.nn.avg_pool(input_tensor, window_size,
|
||||||
|
stride_size, padding)
|
||||||
|
output_data = output_tensor.eval()
|
||||||
|
num_elements = 1
|
||||||
|
for dim_size in output_data.shape:
|
||||||
|
num_elements *= dim_size
|
||||||
|
output_backprop = (self._PRNG.rand(num_elements) *
|
||||||
|
1000).reshape(output_data.shape)
|
||||||
|
input_backprop_tensor = gen_nn_ops._avg_pool_grad(
|
||||||
|
input_tensor.get_shape(), output_backprop, window_size,
|
||||||
|
stride_size, padding)
|
||||||
|
input_backprop = input_backprop_tensor.eval()
|
||||||
|
row_seq = list(range(0, num_rows, row_window_size - 1))
|
||||||
|
col_seq = list(range(0, num_cols, col_window_size - 1))
|
||||||
|
row_seq[-1] += 1
|
||||||
|
col_seq[-1] += 1
|
||||||
|
fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad(
|
||||||
|
input_tensor.get_shape(),
|
||||||
|
output_backprop,
|
||||||
|
row_seq,
|
||||||
|
col_seq,
|
||||||
|
overlapping=True)
|
||||||
|
fap_input_backprop = fap_input_backprop_tensor.eval()
|
||||||
|
self.assertShapeEqual(input_backprop, fap_input_backprop_tensor)
|
||||||
|
self.assertAllClose(input_backprop, fap_input_backprop)
|
||||||
|
|
||||||
|
def testAllInputOptionsThroughGradientError(self):
|
||||||
|
input_shape = (1, 7, 13, 1)
|
||||||
|
input_data = self._GenerateRandomInputTensor(input_shape)
|
||||||
|
pooling_ratio = [1, math.sqrt(2), math.sqrt(3), 1]
|
||||||
|
|
||||||
|
for pseudo_random in True, False:
|
||||||
|
for overlapping in True, False:
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(input_data, shape=input_shape)
|
||||||
|
output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool(
|
||||||
|
input_tensor,
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random=pseudo_random,
|
||||||
|
overlapping=overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
output_data = output_tensor.eval()
|
||||||
|
output_shape = output_data.shape
|
||||||
|
# error_margin and delta setting is similar to avg_pool_grad.
|
||||||
|
error_margin = 1e-4
|
||||||
|
gradient_error = tf.test.compute_gradient_error(
|
||||||
|
input_tensor,
|
||||||
|
input_shape,
|
||||||
|
output_tensor,
|
||||||
|
output_shape,
|
||||||
|
x_init_value=input_data.reshape(input_shape),
|
||||||
|
delta=1e-2)
|
||||||
|
self.assertLess(gradient_error, error_margin)
|
||||||
|
|
||||||
|
def testDifferentTensorShapesThroughGradientError(self):
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
pooling_ratio = [1, math.sqrt(3), math.sqrt(2), 1]
|
||||||
|
for num_batches in [1, 2]:
|
||||||
|
for num_rows in [5, 13]:
|
||||||
|
for num_cols in [5, 11]:
|
||||||
|
for num_channels in [1, 3]:
|
||||||
|
input_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
input_data = self._GenerateRandomInputTensor(input_shape)
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(input_data, shape=input_shape)
|
||||||
|
output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool(
|
||||||
|
input_tensor,
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random=pseudo_random,
|
||||||
|
overlapping=overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
output_data = output_tensor.eval()
|
||||||
|
output_shape = output_data.shape
|
||||||
|
# error_margin and delta setting is similar to avg_pool_grad.
|
||||||
|
error_margin = 1e-4
|
||||||
|
gradient_error = tf.test.compute_gradient_error(
|
||||||
|
input_tensor,
|
||||||
|
input_shape,
|
||||||
|
output_tensor,
|
||||||
|
output_shape,
|
||||||
|
x_init_value=input_data.reshape(input_shape),
|
||||||
|
delta=1e-2)
|
||||||
|
self.assertLess(gradient_error, error_margin)
|
||||||
|
|
||||||
|
def testLargePoolingRatioThroughGradientError(self):
|
||||||
|
input_shape = (1, 17, 23, 1)
|
||||||
|
input_data = self._GenerateRandomInputTensor(input_shape)
|
||||||
|
pooling_ratio = (1, math.sqrt(13), math.sqrt(7), 1)
|
||||||
|
output_shape = [int(a / b) for a, b in zip(input_shape, pooling_ratio)]
|
||||||
|
overlapping = True
|
||||||
|
pseudo_random = False
|
||||||
|
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(input_data, shape=input_shape)
|
||||||
|
output_tensor, unused_a, unused_b = tf.nn.fractional_avg_pool(
|
||||||
|
input_tensor,
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random=pseudo_random,
|
||||||
|
overlapping=overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
# error_margin and delta setting is similar to avg_pool_grad.
|
||||||
|
error_margin = 1e-4
|
||||||
|
gradient_error = tf.test.compute_gradient_error(
|
||||||
|
input_tensor,
|
||||||
|
input_shape,
|
||||||
|
output_tensor,
|
||||||
|
output_shape,
|
||||||
|
x_init_value=input_data.reshape(input_shape),
|
||||||
|
delta=1e-2)
|
||||||
|
self.assertLess(gradient_error, error_margin)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
582
tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
Normal file
582
tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
Normal file
@ -0,0 +1,582 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for fractional max pool operation."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.python.ops import gen_nn_ops
|
||||||
|
|
||||||
|
|
||||||
|
class FractionalMaxPoolTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
# Random number generate with seed.
|
||||||
|
_PRNG = np.random.RandomState(341261)
|
||||||
|
_SEED = 123456
|
||||||
|
_SEED2 = 654321
|
||||||
|
|
||||||
|
def _MaxPoolAlongRows(self, input_matrix, row_seq, overlapping):
|
||||||
|
"""Perform max pool along row of a 2-D matrix based on row_seq.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_matrix: A 2-D matrix.
|
||||||
|
row_seq: Cumulative pooling sequence along row.
|
||||||
|
overlapping: Whether or not use overlapping when pooling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 2-D matrix, with
|
||||||
|
* num_rows = len(row_seq)-1
|
||||||
|
* num_cols = input_matrix.num_cols.
|
||||||
|
"""
|
||||||
|
output_image = np.zeros(input_matrix.shape[1])
|
||||||
|
row_max = row_seq[-1]
|
||||||
|
for i in range(row_seq.shape[0] - 1):
|
||||||
|
row_start = row_seq[i]
|
||||||
|
row_end = row_seq[i + 1] + 1 if overlapping else row_seq[i + 1]
|
||||||
|
row_end = min(row_end, row_max)
|
||||||
|
output_image = np.vstack((output_image,
|
||||||
|
np.amax(input_matrix[row_start:row_end, :],
|
||||||
|
axis=0))) # axis 0 is along row
|
||||||
|
# remove the sentinel row
|
||||||
|
return output_image[1:, :]
|
||||||
|
|
||||||
|
def _MaxPoolAlongCols(self, input_matrix, col_seq, overlapping):
|
||||||
|
"""Perform max pool along column of a 2-D matrix based on col_seq.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_matrix: A 2-D matrix.
|
||||||
|
col_seq: Cumulative pooling sequence along column.
|
||||||
|
overlapping: Whether or not use overlapping when pooling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 2-D matrix, with
|
||||||
|
* num_rows = input_matrix.num_rows
|
||||||
|
* num_cols = len(col_seq)-1.
|
||||||
|
"""
|
||||||
|
input_matrix = input_matrix.transpose()
|
||||||
|
output_matrix = self._MaxPoolAlongRows(input_matrix, col_seq, overlapping)
|
||||||
|
return output_matrix.transpose()
|
||||||
|
|
||||||
|
def _GetExpectedFractionalMaxPoolResult(self, input_tensor, row_seq, col_seq,
|
||||||
|
overlapping):
|
||||||
|
"""Get expected fractional max pool result.
|
||||||
|
|
||||||
|
row_seq and col_seq together defines the fractional pooling region.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: Original input tensor, assuming it is a 4-D tensor, with
|
||||||
|
dimension as [batch, height/row, width/column, channels/depth].
|
||||||
|
row_seq: Cumulative pooling sequence along row.
|
||||||
|
col_seq: Cumulative pooling sequence along column.
|
||||||
|
overlapping: Use overlapping when doing pooling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 4-D tensor that is the result of max pooling on input_tensor based on
|
||||||
|
pooling region defined by row_seq and col_seq, conditioned on whether or
|
||||||
|
not overlapping is used.
|
||||||
|
"""
|
||||||
|
input_shape = input_tensor.shape
|
||||||
|
output_shape = (input_shape[0], len(row_seq) - 1, len(col_seq) - 1,
|
||||||
|
input_shape[3])
|
||||||
|
output_tensor = np.zeros(shape=output_shape, dtype=input_tensor.dtype)
|
||||||
|
for batch in range(input_shape[0]):
|
||||||
|
for channel in range(input_shape[3]):
|
||||||
|
two_dim_slice = input_tensor[batch, :, :, channel]
|
||||||
|
tmp = self._MaxPoolAlongRows(two_dim_slice, row_seq, overlapping)
|
||||||
|
output_tensor[batch, :, :, channel] = self._MaxPoolAlongCols(
|
||||||
|
tmp, col_seq, overlapping)
|
||||||
|
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
def _ValidateFractionalMaxPoolResult(self, input_tensor, pooling_ratio,
|
||||||
|
pseudo_random, overlapping):
|
||||||
|
"""Validate FractionalMaxPool's result against expected.
|
||||||
|
|
||||||
|
Expected result is computed given input_tensor, and pooling region defined
|
||||||
|
by row_seq and col_seq.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: A tensor or numpy ndarray.
|
||||||
|
pooling_ratio: A list or tuple of length 4, first and last element be 1.
|
||||||
|
pseudo_random: Use pseudo random method to generate pooling sequence.
|
||||||
|
overlapping: Use overlapping when pooling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
with self.test_session() as sess:
|
||||||
|
p, r, c = tf.nn.fractional_max_pool(input_tensor,
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random,
|
||||||
|
overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
actual, row_seq, col_seq = sess.run([p, r, c])
|
||||||
|
expected = self._GetExpectedFractionalMaxPoolResult(input_tensor, row_seq,
|
||||||
|
col_seq, overlapping)
|
||||||
|
self.assertShapeEqual(expected, p)
|
||||||
|
self.assertAllClose(expected, actual)
|
||||||
|
|
||||||
|
def _testVisually(self):
|
||||||
|
"""Manual test by printing out intermediate result of a small random tensor.
|
||||||
|
|
||||||
|
Since _GetExpectedFractionalMaxPoolResult is 'automated', it feel safer to
|
||||||
|
have a test case that you can see what's happening.
|
||||||
|
This test will generate a small, random, int 2D matrix, and feed it to
|
||||||
|
FractinalMaxPool and _GetExpectedFractionalMaxPoolResult.
|
||||||
|
"""
|
||||||
|
num_rows = 6
|
||||||
|
num_cols = 6
|
||||||
|
tensor_shape = (1, num_rows, num_cols, 1)
|
||||||
|
pseudo_random = False
|
||||||
|
for overlapping in True, False:
|
||||||
|
print("-" * 70)
|
||||||
|
print("Testing FractionalMaxPool with overlapping = {}".format(
|
||||||
|
overlapping))
|
||||||
|
rand_mat = self._PRNG.randint(10, size=tensor_shape)
|
||||||
|
pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
|
||||||
|
with self.test_session() as sess:
|
||||||
|
p, r, c = tf.nn.fractional_max_pool(rand_mat,
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random,
|
||||||
|
overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
tensor_output, row_seq, col_seq = sess.run([p, r, c])
|
||||||
|
expected_result = self._GetExpectedFractionalMaxPoolResult(rand_mat,
|
||||||
|
row_seq,
|
||||||
|
col_seq,
|
||||||
|
overlapping)
|
||||||
|
print("row sequence:")
|
||||||
|
print(row_seq)
|
||||||
|
print("column sequence:")
|
||||||
|
print(col_seq)
|
||||||
|
|
||||||
|
print("Input:")
|
||||||
|
# Print input with pooling region marked.
|
||||||
|
for i in range(num_rows):
|
||||||
|
row_to_print = []
|
||||||
|
for j in range(num_cols):
|
||||||
|
if j in col_seq:
|
||||||
|
row_to_print.append("|")
|
||||||
|
row_to_print.append(str(rand_mat[0, i, j, 0]))
|
||||||
|
row_to_print.append("|")
|
||||||
|
if i in row_seq:
|
||||||
|
print("-" * 2 * len(row_to_print))
|
||||||
|
print(" ".join(row_to_print))
|
||||||
|
print("-" * 2 * len(row_to_print))
|
||||||
|
|
||||||
|
print("Output from FractionalMaxPool:")
|
||||||
|
print(tensor_output[0, :, :, 0])
|
||||||
|
print("Expected result:")
|
||||||
|
print(expected_result[0, :, :, 0])
|
||||||
|
|
||||||
|
def testAllInputOptions(self):
|
||||||
|
"""Try all possible input options for fractional_max_pool.
|
||||||
|
"""
|
||||||
|
num_batches = 5
|
||||||
|
num_channels = 3
|
||||||
|
num_rows = 20
|
||||||
|
num_cols = 30
|
||||||
|
for pseudo_random in True, False:
|
||||||
|
for overlapping in True, False:
|
||||||
|
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
# random tensor with value in [-500.0, 500.0)
|
||||||
|
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
|
||||||
|
self._ValidateFractionalMaxPoolResult(
|
||||||
|
rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
|
||||||
|
overlapping)
|
||||||
|
|
||||||
|
def testIntegerTensorInput(self):
|
||||||
|
"""Test it works fine when input tensor is integer type.
|
||||||
|
"""
|
||||||
|
num_batches = 5
|
||||||
|
num_channels = 3
|
||||||
|
num_rows = 20
|
||||||
|
num_cols = 30
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
rand_mat = self._PRNG.randint(1000, size=tensor_shape)
|
||||||
|
self._ValidateFractionalMaxPoolResult(rand_mat,
|
||||||
|
[1, math.sqrt(3), math.sqrt(2), 1],
|
||||||
|
pseudo_random, overlapping)
|
||||||
|
|
||||||
|
def testDifferentTensorShapes(self):
|
||||||
|
"""Test different shapes of input tensor.
|
||||||
|
|
||||||
|
Mainly test different combinations of num_rows and num_cols.
|
||||||
|
"""
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
for num_batches in [1, 3]:
|
||||||
|
for num_channels in [1, 3]:
|
||||||
|
for num_rows in [10, 20, 50]:
|
||||||
|
for num_cols in [10, 20, 50]:
|
||||||
|
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
# random tensor with value in [-500.0, 500.0)
|
||||||
|
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
|
||||||
|
self._ValidateFractionalMaxPoolResult(
|
||||||
|
rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
|
||||||
|
overlapping)
|
||||||
|
|
||||||
|
def testLargePoolingRatio(self):
|
||||||
|
"""Test when pooling ratio is not within [1, 2).
|
||||||
|
"""
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
num_batches = 3
|
||||||
|
num_channels = 3
|
||||||
|
num_rows = 30
|
||||||
|
num_cols = 50
|
||||||
|
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
for row_ratio in [math.sqrt(11), math.sqrt(37)]:
|
||||||
|
for col_ratio in [math.sqrt(11), math.sqrt(27)]:
|
||||||
|
# random tensor with value in [-500.0, 500.0)
|
||||||
|
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
|
||||||
|
self._ValidateFractionalMaxPoolResult(rand_mat,
|
||||||
|
[1, row_ratio, col_ratio, 1],
|
||||||
|
pseudo_random, overlapping)
|
||||||
|
|
||||||
|
def testDivisiblePoolingRatio(self):
|
||||||
|
"""Test when num of rows/cols can evenly divide pooling ratio.
|
||||||
|
|
||||||
|
This is a case regular max pooling can handle. Should be handled by
|
||||||
|
fractional pooling as well.
|
||||||
|
"""
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
num_batches = 3
|
||||||
|
num_channels = 3
|
||||||
|
num_rows = 30
|
||||||
|
num_cols = 50
|
||||||
|
tensor_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
# random tensor with value in [-500.0, 500.0)
|
||||||
|
rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
|
||||||
|
self._ValidateFractionalMaxPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
|
||||||
|
overlapping)
|
||||||
|
|
||||||
|
|
||||||
|
class FractionalMaxPoolGradTest(tf.test.TestCase):
|
||||||
|
"""Tests for FractionalMaxPoolGrad.
|
||||||
|
|
||||||
|
Two types of tests for FractionalMaxPoolGrad.
|
||||||
|
1) Test fractional_max_pool_grad() directly.
|
||||||
|
This type of test relies on gen_nn_ops._max_pool_grad() returns the correct
|
||||||
|
result. For example:
|
||||||
|
* input_tensor_shape = (1, 10, 10, 1)
|
||||||
|
* window_size = (1, 2, 2, 1)
|
||||||
|
* stride_size = (1, 2, 2, 1)
|
||||||
|
* padding: not really import, since 10/2 is divisible
|
||||||
|
max pooling should generate the same result as fractional max pooling with:
|
||||||
|
* row_sequence = [0, 2, 4, 6, 8, 10]
|
||||||
|
* col_sequence = [0, 2, 4, 6, 8, 10]
|
||||||
|
* overlapping = False
|
||||||
|
This also means their gradients in such case will be the same.
|
||||||
|
|
||||||
|
Similarly, when
|
||||||
|
* input_tensor_shape = (1, 7, 7, 1)
|
||||||
|
* window_size = (1, 3, 3, 1)
|
||||||
|
* stride_size = (1, 2, 2, 1)
|
||||||
|
* padding: not important
|
||||||
|
max pooling should generate the same result as fractional max pooling with:
|
||||||
|
* row_sequence = [0, 2, 4, 7]
|
||||||
|
* col_sequence = [0, 2, 4, 7]
|
||||||
|
* overlapping = True
|
||||||
|
2) Test through compute_gradient_error()
|
||||||
|
"""
|
||||||
|
|
||||||
|
_PRNG = np.random.RandomState(341261)
|
||||||
|
_SEED = 123456
|
||||||
|
_SEED2 = 654321
|
||||||
|
|
||||||
|
def _GenerateUniqueRandomInputTensor(self, shape):
|
||||||
|
"""Generate 'unqiue' random input tensor.
|
||||||
|
|
||||||
|
'Unique' means there's no collision values in the tensor, all elements are
|
||||||
|
different. This is done by generating sequence of integers with step of 1
|
||||||
|
and then randomly shuffle these integers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: Shape of the tensor desired.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy ndarray with size = shape and dtype = numpy.float32.
|
||||||
|
"""
|
||||||
|
num_elements = 1
|
||||||
|
for size in shape:
|
||||||
|
num_elements *= size
|
||||||
|
x = np.arange(num_elements, dtype=np.float32)
|
||||||
|
self._PRNG.shuffle(x)
|
||||||
|
return x.reshape(shape)
|
||||||
|
|
||||||
|
def testDirectNotUseOverlapping(self):
|
||||||
|
for num_batches in [1, 3]:
|
||||||
|
for row_window_size in [2, 5]:
|
||||||
|
for col_window_size in [2, 4]:
|
||||||
|
num_rows = row_window_size * 5
|
||||||
|
num_cols = col_window_size * 7
|
||||||
|
for num_channels in [1, 2]:
|
||||||
|
input_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(self._GenerateUniqueRandomInputTensor(
|
||||||
|
input_shape))
|
||||||
|
window_size = [1, row_window_size, col_window_size, 1]
|
||||||
|
stride_size = [1, row_window_size, col_window_size, 1]
|
||||||
|
padding = "VALID"
|
||||||
|
output_tensor = tf.nn.max_pool(input_tensor, window_size,
|
||||||
|
stride_size, padding)
|
||||||
|
output_data = output_tensor.eval()
|
||||||
|
output_backprop = self._PRNG.randint(100, size=output_data.shape)
|
||||||
|
input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor,
|
||||||
|
output_tensor,
|
||||||
|
output_backprop,
|
||||||
|
window_size,
|
||||||
|
stride_size,
|
||||||
|
padding)
|
||||||
|
input_backprop = input_backprop_tensor.eval()
|
||||||
|
row_seq = list(range(0, num_rows + 1, row_window_size))
|
||||||
|
col_seq = list(range(0, num_cols + 1, col_window_size))
|
||||||
|
fmp_input_backprop_tensor = gen_nn_ops._fractional_max_pool_grad(
|
||||||
|
input_tensor,
|
||||||
|
output_tensor,
|
||||||
|
output_backprop,
|
||||||
|
row_seq,
|
||||||
|
col_seq,
|
||||||
|
overlapping=False)
|
||||||
|
fmp_input_backprop = fmp_input_backprop_tensor.eval()
|
||||||
|
self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
|
||||||
|
self.assertAllClose(input_backprop, fmp_input_backprop)
|
||||||
|
|
||||||
|
def testDirectUseOverlapping(self):
|
||||||
|
for num_batches in [1, 3]:
|
||||||
|
for row_window_size in [2, 5]:
|
||||||
|
for col_window_size in [2, 4]:
|
||||||
|
num_rows = (row_window_size - 1) * 5 + 1
|
||||||
|
num_cols = (col_window_size - 1) * 7 + 1
|
||||||
|
for num_channels in [1, 2]:
|
||||||
|
input_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(self._GenerateUniqueRandomInputTensor(
|
||||||
|
input_shape))
|
||||||
|
window_size = [1, row_window_size, col_window_size, 1]
|
||||||
|
stride_size = [1, row_window_size - 1, col_window_size - 1, 1]
|
||||||
|
padding = "VALID"
|
||||||
|
output_tensor = tf.nn.max_pool(input_tensor, window_size,
|
||||||
|
stride_size, padding)
|
||||||
|
output_data = output_tensor.eval()
|
||||||
|
output_backprop = self._PRNG.randint(100, size=output_data.shape)
|
||||||
|
input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor,
|
||||||
|
output_tensor,
|
||||||
|
output_backprop,
|
||||||
|
window_size,
|
||||||
|
stride_size,
|
||||||
|
padding)
|
||||||
|
input_backprop = input_backprop_tensor.eval()
|
||||||
|
row_seq = list(range(0, num_rows, row_window_size - 1))
|
||||||
|
col_seq = list(range(0, num_cols, col_window_size - 1))
|
||||||
|
row_seq[-1] += 1
|
||||||
|
col_seq[-1] += 1
|
||||||
|
fmp_input_backprop_tensor = gen_nn_ops._fractional_max_pool_grad(
|
||||||
|
input_tensor,
|
||||||
|
output_tensor,
|
||||||
|
output_backprop,
|
||||||
|
row_seq,
|
||||||
|
col_seq,
|
||||||
|
overlapping=True)
|
||||||
|
fmp_input_backprop = fmp_input_backprop_tensor.eval()
|
||||||
|
self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
|
||||||
|
self.assertAllClose(input_backprop, fmp_input_backprop)
|
||||||
|
|
||||||
|
def testAllInputOptionsThroughGradientError(self):
|
||||||
|
input_shape = (1, 7, 13, 1)
|
||||||
|
input_data = self._GenerateUniqueRandomInputTensor(input_shape)
|
||||||
|
# Add some randomness to make input_data not so 'integer'
|
||||||
|
input_data += self._PRNG.random_sample(input_shape)
|
||||||
|
pooling_ratio = [1, math.sqrt(2), math.sqrt(3), 1]
|
||||||
|
|
||||||
|
for pseudo_random in True, False:
|
||||||
|
for overlapping in True, False:
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(input_data, shape=input_shape)
|
||||||
|
output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool(
|
||||||
|
input_tensor,
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random=pseudo_random,
|
||||||
|
overlapping=overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
output_data = output_tensor.eval()
|
||||||
|
output_shape = output_data.shape
|
||||||
|
# error_margin and delta setting is similar to max_pool_grad.
|
||||||
|
error_margin = 1e-3
|
||||||
|
gradient_error = tf.test.compute_gradient_error(
|
||||||
|
input_tensor,
|
||||||
|
input_shape,
|
||||||
|
output_tensor,
|
||||||
|
output_shape,
|
||||||
|
x_init_value=input_data.reshape(input_shape),
|
||||||
|
delta=1e-2)
|
||||||
|
self.assertLess(gradient_error, error_margin)
|
||||||
|
|
||||||
|
def testDifferentTensorShapesThroughGradientError(self):
|
||||||
|
pseudo_random = True
|
||||||
|
overlapping = True
|
||||||
|
pooling_ratio = [1, math.sqrt(3), math.sqrt(2), 1]
|
||||||
|
for num_batches in [1, 2]:
|
||||||
|
for num_rows in [5, 13]:
|
||||||
|
for num_cols in [5, 11]:
|
||||||
|
for num_channels in [1, 3]:
|
||||||
|
input_shape = (num_batches, num_rows, num_cols, num_channels)
|
||||||
|
input_data = self._GenerateUniqueRandomInputTensor(input_shape)
|
||||||
|
# Add some randomness to make input_data not so 'integer'
|
||||||
|
input_data += self._PRNG.random_sample(input_shape)
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(input_data, shape=input_shape)
|
||||||
|
output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool(
|
||||||
|
input_tensor,
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random=pseudo_random,
|
||||||
|
overlapping=overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
output_data = output_tensor.eval()
|
||||||
|
output_shape = output_data.shape
|
||||||
|
# error_margin and delta setting is similar to max_pool_grad.
|
||||||
|
error_margin = 1e-3
|
||||||
|
gradient_error = tf.test.compute_gradient_error(
|
||||||
|
input_tensor,
|
||||||
|
input_shape,
|
||||||
|
output_tensor,
|
||||||
|
output_shape,
|
||||||
|
x_init_value=input_data.reshape(input_shape),
|
||||||
|
delta=1e-2)
|
||||||
|
self.assertLess(gradient_error, error_margin)
|
||||||
|
|
||||||
|
def testLargePoolingRatioThroughGradientError(self):
|
||||||
|
input_shape = (1, 17, 23, 1)
|
||||||
|
input_data = self._GenerateUniqueRandomInputTensor(input_shape)
|
||||||
|
# Add some randomness to make input_data not so 'integer'
|
||||||
|
input_data += self._PRNG.random_sample(input_shape)
|
||||||
|
pooling_ratio = (1, math.sqrt(13), math.sqrt(7), 1)
|
||||||
|
output_shape = [int(a / b) for a, b in zip(input_shape, pooling_ratio)]
|
||||||
|
overlapping = True
|
||||||
|
pseudo_random = False
|
||||||
|
|
||||||
|
with self.test_session() as _:
|
||||||
|
input_tensor = tf.constant(input_data, shape=input_shape)
|
||||||
|
output_tensor, unused_a, unused_b = tf.nn.fractional_max_pool(
|
||||||
|
input_tensor,
|
||||||
|
pooling_ratio,
|
||||||
|
pseudo_random=pseudo_random,
|
||||||
|
overlapping=overlapping,
|
||||||
|
deterministic=True,
|
||||||
|
seed=self._SEED,
|
||||||
|
seed2=self._SEED2)
|
||||||
|
# error_margin and delta setting is similar to max_pool_grad.
|
||||||
|
error_margin = 1e-3
|
||||||
|
gradient_error = tf.test.compute_gradient_error(
|
||||||
|
input_tensor,
|
||||||
|
input_shape,
|
||||||
|
output_tensor,
|
||||||
|
output_shape,
|
||||||
|
x_init_value=input_data.reshape(input_shape),
|
||||||
|
delta=1e-2)
|
||||||
|
self.assertLess(gradient_error, error_margin)
|
||||||
|
|
||||||
|
def testWhenRepeatedMaxValueInPoolingRegion(self):
|
||||||
|
"""Test when there's repeating value in pooling region.
|
||||||
|
|
||||||
|
There's no formal definition for what the gradient should be when there're
|
||||||
|
multiple max value within a pooling cell. Such as
|
||||||
|
| 1 5 |
|
||||||
|
| 5 3 |
|
||||||
|
The expected result depends heavily on implementation, if someone swap the
|
||||||
|
order of a nested for loop when walking through the tensor, result would be
|
||||||
|
very different.
|
||||||
|
|
||||||
|
The goal of this test is to alert when someone else change the
|
||||||
|
implementation. Current implementation scans row-by-row.
|
||||||
|
"""
|
||||||
|
input_data = [5.0, 4.0, 6.0, 7.0,
|
||||||
|
3.0, 5.0, 9.0, 6.0,
|
||||||
|
8.0, 8.0, 9.0, 5.0,
|
||||||
|
7.0, 4.0, 0.0, 0.0] # pyformat: disable
|
||||||
|
input_size = [1, 4, 4, 1]
|
||||||
|
output_backprop = [12.0, 15.0,
|
||||||
|
17.0, -5.0,
|
||||||
|
6.0, 21.0] # pyformat: disable
|
||||||
|
row_seq = [0, 1, 3, 4]
|
||||||
|
col_seq = [0, 2, 4]
|
||||||
|
output_data_not_overlapping = [5.0, 7.0,
|
||||||
|
8.0, 9.0,
|
||||||
|
7.0, 0.0] # pyformat: disable
|
||||||
|
output_data_overlapping = [9.0, 9.0,
|
||||||
|
9.0, 9.0,
|
||||||
|
7.0, 0.0] # pyformat: disable
|
||||||
|
output_size = [1, 3, 2, 1]
|
||||||
|
expected_input_backprop_not_overlapping = np.reshape(
|
||||||
|
[12.0, 0.0, 0.0, 15.0,
|
||||||
|
0.0, 0.0, -5.0, 0.0,
|
||||||
|
17.0, 0.0, 0.0, 0.0,
|
||||||
|
6.0, 0.0, 21.0, 0.0],
|
||||||
|
input_size) # pyformat: disable
|
||||||
|
expected_input_backprop_overlapping = np.reshape(
|
||||||
|
[0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 39.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0,
|
||||||
|
6.0, 0.0, 21.0, 0.0],
|
||||||
|
input_size) # pyformat: disable
|
||||||
|
with self.test_session() as _:
|
||||||
|
# Test when overlapping is False
|
||||||
|
input_tensor = tf.constant(input_data, shape=input_size)
|
||||||
|
output_tensor = tf.constant(output_data_not_overlapping,
|
||||||
|
shape=output_size)
|
||||||
|
grad = tf.constant(output_backprop, shape=output_size)
|
||||||
|
r = gen_nn_ops._fractional_max_pool_grad(
|
||||||
|
input_tensor,
|
||||||
|
output_tensor,
|
||||||
|
grad,
|
||||||
|
row_seq,
|
||||||
|
col_seq,
|
||||||
|
overlapping=False)
|
||||||
|
input_backprop_not_overlapping = r.eval()
|
||||||
|
self.assertShapeEqual(
|
||||||
|
np.reshape(expected_input_backprop_not_overlapping, input_size), r)
|
||||||
|
self.assertAllClose(expected_input_backprop_not_overlapping,
|
||||||
|
input_backprop_not_overlapping)
|
||||||
|
# Test when overlapping is True
|
||||||
|
output_tensor = tf.constant(output_data_overlapping, shape=output_size)
|
||||||
|
r = gen_nn_ops._fractional_max_pool_grad(
|
||||||
|
input_tensor, output_tensor, grad, row_seq, col_seq, overlapping=True)
|
||||||
|
input_backprop_overlapping = r.eval()
|
||||||
|
self.assertShapeEqual(
|
||||||
|
np.reshape(expected_input_backprop_overlapping, input_size), r)
|
||||||
|
self.assertAllClose(expected_input_backprop_overlapping,
|
||||||
|
input_backprop_overlapping)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
@ -181,6 +181,8 @@ AvgPool
|
|||||||
MaxPool
|
MaxPool
|
||||||
Softmax
|
Softmax
|
||||||
LogSoftmax
|
LogSoftmax
|
||||||
|
FractionalAvgPoolGrad
|
||||||
|
FractionalMaxPoolGrad
|
||||||
|
|
||||||
# parsing_ops
|
# parsing_ops
|
||||||
ParseExample
|
ParseExample
|
||||||
|
@ -132,6 +132,8 @@ to the `Convolution` section for details about the padding calculation.
|
|||||||
@@max_pool_with_argmax
|
@@max_pool_with_argmax
|
||||||
@@avg_pool3d
|
@@avg_pool3d
|
||||||
@@max_pool3d
|
@@max_pool3d
|
||||||
|
@@fractional_avg_pool
|
||||||
|
@@fractional_max_pool
|
||||||
|
|
||||||
## Morphological filtering
|
## Morphological filtering
|
||||||
|
|
||||||
|
@ -361,6 +361,53 @@ def _MaxPoolGrad(op, grad):
|
|||||||
data_format=op.get_attr("data_format"))
|
data_format=op.get_attr("data_format"))
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("FractionalMaxPool")
|
||||||
|
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
|
||||||
|
"""Returns gradient for FractionalMaxPool.
|
||||||
|
|
||||||
|
Since FractionalMaxPool has three outputs, there are three gradients passed in
|
||||||
|
for each of the outputs. Only the first one is useful, the other two gradients
|
||||||
|
are empty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op: The FractionalMaxPoolOp.
|
||||||
|
grad_0: Gradient with respect to op.outputs[0]
|
||||||
|
unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
|
||||||
|
unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Input backprop for FractionalMaxPool op.
|
||||||
|
"""
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
return gen_nn_ops._fractional_max_pool_grad(op.inputs[0], op.outputs[0],
|
||||||
|
grad_0, op.outputs[1],
|
||||||
|
op.outputs[2],
|
||||||
|
op.get_attr("overlapping"))
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("FractionalAvgPool")
|
||||||
|
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
|
||||||
|
"""Returns gradient for FractionalAvgPool.
|
||||||
|
|
||||||
|
Since FractionalAvgPool has three outputs, there are three gradients passed in
|
||||||
|
for each of the outputs. Only the first one is useful, the other two gradients
|
||||||
|
are empty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op: The FractionalAvgPoolOp.
|
||||||
|
grad_0: Gradient with respect to op.outputs[0]
|
||||||
|
unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
|
||||||
|
unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Input backprop for FractionalAvgPool op.
|
||||||
|
"""
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
return gen_nn_ops._fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
|
||||||
|
op.outputs[1], op.outputs[2],
|
||||||
|
op.get_attr("overlapping"))
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("BatchNormWithGlobalNormalization")
|
@ops.RegisterGradient("BatchNormWithGlobalNormalization")
|
||||||
def _BatchNormWithGlobalNormalizationGrad(op, grad):
|
def _BatchNormWithGlobalNormalizationGrad(op, grad):
|
||||||
"""Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.
|
"""Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.
|
||||||
|
@ -941,6 +941,39 @@ def _AvgPoolGradShape(op):
|
|||||||
return [tensor_shape.unknown_shape(ndims=4)]
|
return [tensor_shape.unknown_shape(ndims=4)]
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterShape("FractionalMaxPool")
|
||||||
|
@ops.RegisterShape("FractionalAvgPool")
|
||||||
|
def _fractional_pool_shape(op):
|
||||||
|
input_dims = op.inputs[0].get_shape().with_rank(4).as_list()
|
||||||
|
pooling_ratio = op.get_attr("pooling_ratio")
|
||||||
|
output_dims = np.divide(input_dims, pooling_ratio).astype(int)
|
||||||
|
return [
|
||||||
|
# output.
|
||||||
|
tensor_shape.TensorShape(output_dims),
|
||||||
|
# row_pooling_sequence.
|
||||||
|
tensor_shape.TensorShape([output_dims[1]]),
|
||||||
|
# col_pooling_sequence.
|
||||||
|
tensor_shape.TensorShape([output_dims[2]])
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterShape("FractionalMaxPoolGrad")
|
||||||
|
def _fractional_max_pool_grad_shape(op):
|
||||||
|
"""Shape function for the FractionalMaxPoolGrad op."""
|
||||||
|
orig_input_shape = op.inputs[0].get_shape().with_rank(4)
|
||||||
|
return [orig_input_shape]
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterShape("FractionalAvgPoolGrad")
|
||||||
|
def _fractional_avg_pool_grad_shape(op):
|
||||||
|
"""Shape function for the FractionalAvgPoolGrad op."""
|
||||||
|
orig_input_shape = tensor_util.constant_value(op.inputs[0])
|
||||||
|
if orig_input_shape is not None:
|
||||||
|
return [tensor_shape.TensorShape(orig_input_shape.tolist())]
|
||||||
|
else:
|
||||||
|
return [tensor_shape.unknown_shape(ndims=4)]
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("Conv2DBackpropFilter")
|
@ops.RegisterShape("Conv2DBackpropFilter")
|
||||||
def _Conv2DBackpropFilterShape(op):
|
def _Conv2DBackpropFilterShape(op):
|
||||||
"""Shape function for the Conv2DBackpropFilter op."""
|
"""Shape function for the Conv2DBackpropFilter op."""
|
||||||
|
Loading…
Reference in New Issue
Block a user