Dynamic ksize and strides with MaxPool (#11875)
* Dynamic ksize with max_pool

  This addresses the issue raised in #4746, where ksize for max_pool is static (an attr). It changes ksize and strides to input tensors so that they can now be dynamic. Fixes #4746.

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add dynamic ksize to MaxPoolGrad and MaxPoolGradGrad

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test cases for max_pool_v2

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix GPU Jenkins issue

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Enable MaxPoolV2 on GPU

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Hide MaxPoolV2 and other fixes

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent 02d6bc185c
commit 98f0e1efec

Areas touched: tensorflow/core/framework, tensorflow/core/kernels,
tensorflow/core/ops, tensorflow/python.
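In user-facing terms, the new MaxPoolV2 op lets the pooling window be chosen at run time. A minimal sketch of the intended usage (TF 1.x graph API; gen_nn_ops._max_pool_v2 is the generated wrapper exercised by the tests in this commit):

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops

    # ksize and strides are ordinary int32 tensors now, not attrs frozen
    # into the graph, so one graph can be pooled with different windows.
    x = tf.placeholder(tf.float32, shape=[1, 4, 4, 1])
    ksize = tf.placeholder(tf.int32, shape=[4])
    strides = tf.placeholder(tf.int32, shape=[4])
    y = gen_nn_ops._max_pool_v2(x, ksize=ksize, strides=strides,
                                padding="VALID")

    with tf.Session() as sess:
      data = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
      # 2x2 window with stride 2 ...
      print(sess.run(y, {x: data, ksize: [1, 2, 2, 1], strides: [1, 2, 2, 1]}))
      # ... or a 3x3 window with stride 1, without rebuilding the graph.
      print(sess.run(y, {x: data, ksize: [1, 3, 3, 1], strides: [1, 1, 1, 1]}))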
@@ -673,6 +673,116 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
   return Status::OK();
 }
 
+Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
+  ShapeHandle input_shape;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
+
+  string data_format;
+  Status s = c->GetAttr("data_format", &data_format);
+
+  std::vector<int32> kernel_sizes;
+  std::vector<int32> strides;
+
+  if (c->num_inputs() + 2 == num_inputs) {
+    TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
+    TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+  } else {
+    // Verify shape of ksize and strides input.
+    ShapeHandle size;
+    DimensionHandle unused;
+    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
+    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
+    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
+    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
+
+    const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
+    if (kernel_sizes_tensor == nullptr) {
+      c->set_output(0, c->UnknownShape());
+      return Status::OK();
+    }
+    kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
+    auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
+    std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
+                kernel_sizes.begin());
+
+    const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
+    if (strides_tensor == nullptr) {
+      c->set_output(0, c->UnknownShape());
+      return Status::OK();
+    }
+    strides.resize(strides_tensor->shape().num_elements());
+    auto strides_vec = strides_tensor->flat<int32>();
+    std::copy_n(&strides_vec(0), strides.size(), strides.begin());
+  }
+
+  if (strides.size() != 4) {
+    return errors::InvalidArgument(
+        "MaxPool requires the stride attribute to contain 4 values, but "
+        "got: ",
+        strides.size());
+  }
+  if (kernel_sizes.size() != 4) {
+    return errors::InvalidArgument(
+        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
+        kernel_sizes.size());
+  }
+
+  int32 stride_rows, stride_cols, stride_depth;
+  int32 kernel_rows, kernel_cols, kernel_depth;
+
+  if (s.ok() && data_format == "NCHW") {
+    // Canonicalize input shape to NHWC so the shape inference code below can
+    // process it.
+    auto dim = [&](char dimension) {
+      return c->Dim(input_shape, GetTensorDimIndex<2>(FORMAT_NCHW, dimension));
+    };
+    input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('C')}});
+    stride_depth = strides[1];
+    stride_rows = strides[2];
+    stride_cols = strides[3];
+    kernel_depth = kernel_sizes[1];
+    kernel_rows = kernel_sizes[2];
+    kernel_cols = kernel_sizes[3];
+  } else {
+    stride_rows = strides[1];
+    stride_cols = strides[2];
+    stride_depth = strides[3];
+    kernel_rows = kernel_sizes[1];
+    kernel_cols = kernel_sizes[2];
+    kernel_depth = kernel_sizes[3];
+  }
+
+  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+  DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
+  DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
+  DimensionHandle in_depth_dim = c->Dim(input_shape, 3);
+
+  Padding padding;
+  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+  ShapeHandle output_shape;
+  DimensionHandle output_rows, output_cols, output_depth;
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
+
+  output_shape =
+      c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
+  if (data_format == "NCHW") {
+    // Convert output shape back to expected NCHW data format.
+    auto dim = [&](char dimension) {
+      return c->Dim(output_shape, GetTensorDimIndex<2>(FORMAT_NHWC, dimension));
+    };
+    output_shape = c->MakeShape({{dim('N'), dim('C'), dim('0'), dim('1')}});
+  }
+
+  c->set_output(0, output_shape);
+  return Status::OK();
+}
+
 Status Pool3DShape(shape_inference::InferenceContext* c) {
   ShapeHandle input_shape;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
 
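A consequence of the input_tensor() branch above, sketched from the Python side (same generated wrapper as in the tests): constant ksize/strides still yield a fully inferred output shape, while fed values degrade to an unknown shape rather than an error.

    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops

    x = tf.placeholder(tf.float32, shape=[1, 8, 8, 3])

    # Graph constants: c->input_tensor() sees the values, so MaxPoolV2Shape
    # computes a static (1, 4, 4, 3) output shape.
    y1 = gen_nn_ops._max_pool_v2(
        x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
    print(y1.get_shape())

    # Placeholders: input_tensor() returns nullptr, so the shape function
    # returns c->UnknownShape() instead of failing.
    k = tf.placeholder(tf.int32, shape=[4])
    s = tf.placeholder(tf.int32, shape=[4])
    y2 = gen_nn_ops._max_pool_v2(x, ksize=k, strides=s, padding="VALID")
    print(y2.get_shape())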
@@ -179,6 +179,9 @@ Status AvgPoolShape(shape_inference::InferenceContext* c);
 // Shape function for MaxPool-like operations.
 Status MaxPoolShape(shape_inference::InferenceContext* c);
 
+// Shape function for MaxPoolV2-like operations.
+Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
+
 // Shape function for 3D Pooling operations.
 Status Pool3DShape(shape_inference::InferenceContext* c);
 
@@ -208,22 +208,26 @@ class MaxPoolingGradOp : public OpKernel {
         errors::InvalidArgument("Default MaxPoolingGradOp only supports NHWC ",
                                 "on device type ",
                                 DeviceTypeString(context->device_type())));
-    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
-    OP_REQUIRES(context, ksize_.size() == 4,
-                errors::InvalidArgument("Sliding window ksize field must "
-                                        "specify 4 dimensions"));
-    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
-    OP_REQUIRES(context, stride_.size() == 4,
-                errors::InvalidArgument("Sliding window strides field must "
-                                        "specify 4 dimensions"));
+
+    if (context->num_inputs() == 3) {
+      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+      OP_REQUIRES(context, ksize_.size() == 4,
+                  errors::InvalidArgument("Sliding window ksize field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+      OP_REQUIRES(context, stride_.size() == 4,
+                  errors::InvalidArgument("Sliding window strides field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+                  errors::Unimplemented(
+                      "Pooling is not yet supported on the batch dimension."));
+      OP_REQUIRES(
+          context, ksize_[3] == 1 && stride_[3] == 1,
+          errors::Unimplemented(
+              "MaxPoolingGrad is not yet supported on the depth dimension."));
+    }
 
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
-                errors::Unimplemented(
-                    "Pooling is not yet supported on the batch dimension."));
-    OP_REQUIRES(
-        context, ksize_[3] == 1 && stride_[3] == 1,
-        errors::Unimplemented(
-            "MaxPoolingGrad is not yet supported on the depth dimension."));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -250,8 +254,35 @@ class MaxPoolingGradOp : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(),
                                                    tensor_out.shape(),
                                                    &tensor_out_arg_max));
+    std::vector<int32> ksize = ksize_;
+    std::vector<int32> stride = stride_;
+    if (context->num_inputs() == 5) {
+      const Tensor& tensor_ksize = context->input(3);
+      auto value_ksize = tensor_ksize.flat<int32>();
+      ksize.resize(tensor_ksize.shape().num_elements());
+      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
 
-    PoolParameters params{context, ksize_, stride_,
+      const Tensor& tensor_stride = context->input(4);
+      auto value_stride = tensor_stride.flat<int32>();
+      stride.resize(tensor_stride.shape().num_elements());
+      std::copy_n(&value_stride(0), stride.size(), stride.begin());
+    }
+
+    OP_REQUIRES(context, ksize.size() == 4,
+                errors::InvalidArgument("Sliding window ksize field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, stride.size() == 4,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
+                errors::Unimplemented(
+                    "Pooling is not yet supported on the batch dimension."));
+    OP_REQUIRES(
+        context, ksize[3] == 1 && stride[3] == 1,
+        errors::Unimplemented(
+            "MaxPoolingGrad is not yet supported on the depth dimension."));
+
+    PoolParameters params{context, ksize, stride,
                           padding_, FORMAT_NHWC, tensor_in.shape()};
     if (!context->status().ok()) {
       return;
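Because the window is no longer known when the graph is built, the size/batch/depth checks move from the constructor into Compute(), as above; the same pattern repeats in every kernel below. From Python, an illegal window is therefore rejected at session-run time. A sketch of that observable effect (using the forward op, which carries identical checks):

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops

    x = tf.placeholder(tf.float32, shape=[1, 4, 4, 1])
    k = tf.placeholder(tf.int32, shape=[4])
    s = tf.placeholder(tf.int32, shape=[4])
    y = gen_nn_ops._max_pool_v2(x, ksize=k, strides=s, padding="SAME")

    with tf.Session() as sess:
      data = np.ones((1, 4, 4, 1), np.float32)
      try:
        # Pooling over the batch dimension trips the run-time OP_REQUIRES
        # check; with attr-based MaxPool this failed at graph construction.
        sess.run(y, {x: data, k: [2, 2, 2, 1], s: [1, 1, 1, 1]})
      except tf.errors.UnimplementedError as e:
        print(e.message)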
@@ -309,20 +340,22 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                 errors::InvalidArgument("Invalid data format"));
-    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
-    OP_REQUIRES(context, ksize_.size() == 4,
-                errors::InvalidArgument("Sliding window ksize field must "
-                                        "specify 4 dimensions"));
-    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
-    OP_REQUIRES(context, stride_.size() == 4,
-                errors::InvalidArgument("Sliding window strides field must "
-                                        "specify 4 dimensions"));
+    if (context->num_inputs() == 3) {
+      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+      OP_REQUIRES(context, ksize_.size() == 4,
+                  errors::InvalidArgument("Sliding window ksize field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+      OP_REQUIRES(context, stride_.size() == 4,
+                  errors::InvalidArgument("Sliding window strides field must "
+                                          "specify 4 dimensions"));
+      const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
+      const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
+      OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+                  errors::Unimplemented(
+                      "Pooling is not yet supported on the batch dimension."));
+    }
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-    const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
-    const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
-    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
-                errors::Unimplemented(
-                    "Pooling is not yet supported on the batch dimension."));
 
     use_dnn_ = CanUseCudnn();
   }
@@ -343,15 +376,40 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
 
     TensorShape output_shape = tensor_in.shape();
 
+    std::vector<int32> ksize = ksize_;
+    std::vector<int32> stride = stride_;
+    if (context->num_inputs() == 5) {
+      const Tensor& tensor_ksize = context->input(3);
+      auto value_ksize = tensor_ksize.flat<int32>();
+      ksize.resize(tensor_ksize.shape().num_elements());
+      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+      const Tensor& tensor_stride = context->input(4);
+      auto value_stride = tensor_stride.flat<int32>();
+      stride.resize(tensor_stride.shape().num_elements());
+      std::copy_n(&value_stride(0), stride.size(), stride.begin());
+    }
+    OP_REQUIRES(context, ksize.size() == 4,
+                errors::InvalidArgument("Sliding window ksize field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, stride.size() == 4,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 4 dimensions"));
+    const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N');
+    const int32 stride_n = GetTensorDim(stride, data_format_, 'N');
+    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+                errors::Unimplemented(
+                    "Pooling is not yet supported on the batch dimension."));
+
     if (use_dnn_) {
       DnnPoolingGradOp<T>::Compute(
-          context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_,
-          stride_, padding_, data_format_, &tensor_in, &tensor_out,
-          out_backprop, output_shape);
+          context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize,
+          stride, padding_, data_format_, &tensor_in, &tensor_out, out_backprop,
+          output_shape);
     } else {
       CHECK(data_format_ == FORMAT_NHWC)
           << "Non-Cudnn MaxPoolGrad only supports NHWC format";
-      MaxPoolingBackwardCustomKernel<T>(context, ksize_, stride_, padding_,
+      MaxPoolingBackwardCustomKernel<T>(context, ksize, stride, padding_,
                                         &tensor_in, out_backprop, output_shape);
     }
   }
@@ -386,22 +444,25 @@ class MaxPoolingGradGradOp : public OpKernel {
                 errors::InvalidArgument(
                     "Default MaxPoolingGradGradOp only supports NHWC ",
                     "on device type ", DeviceTypeString(context->device_type())));
-    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
-    OP_REQUIRES(context, ksize_.size() == 4,
-                errors::InvalidArgument("Sliding window ksize field must "
-                                        "specify 4 dimensions"));
-    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
-    OP_REQUIRES(context, stride_.size() == 4,
-                errors::InvalidArgument("Sliding window strides field must "
-                                        "specify 4 dimensions"));
 
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
-                errors::Unimplemented(
-                    "Pooling is not yet supported on the batch dimension."));
-    OP_REQUIRES(
-        context, ksize_[3] == 1 && stride_[3] == 1,
-        errors::Unimplemented(
-            "MaxPoolingGradGrad is not yet supported on the depth dimension."));
+
+    if (context->num_inputs() == 3) {
+      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+      OP_REQUIRES(context, ksize_.size() == 4,
+                  errors::InvalidArgument("Sliding window ksize field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+      OP_REQUIRES(context, stride_.size() == 4,
+                  errors::InvalidArgument("Sliding window strides field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+                  errors::Unimplemented(
+                      "Pooling is not yet supported on the batch dimension."));
+      OP_REQUIRES(context, ksize_[3] == 1 && stride_[3] == 1,
+                  errors::Unimplemented("MaxPoolingGradGrad is not yet "
+                                        "supported on the depth dimension."));
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -419,7 +480,35 @@ class MaxPoolingGradGradOp : public OpKernel {
         context, out_grad_backprop.dims() == 4,
         errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));
 
-    PoolParameters params{context, ksize_, stride_,
+    std::vector<int32> ksize = ksize_;
+    std::vector<int32> stride = stride_;
+    if (context->num_inputs() == 5) {
+      const Tensor& tensor_ksize = context->input(3);
+      auto value_ksize = tensor_ksize.flat<int32>();
+      ksize.resize(tensor_ksize.shape().num_elements());
+      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+      const Tensor& tensor_stride = context->input(4);
+      auto value_stride = tensor_stride.flat<int32>();
+      stride.resize(tensor_stride.shape().num_elements());
+      std::copy_n(&value_stride(0), stride.size(), stride.begin());
+    }
+
+    OP_REQUIRES(context, ksize.size() == 4,
+                errors::InvalidArgument("Sliding window ksize field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, stride.size() == 4,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
+                errors::Unimplemented(
+                    "Pooling is not yet supported on the batch dimension."));
+    OP_REQUIRES(
+        context, ksize[3] == 1 && stride[3] == 1,
+        errors::Unimplemented(
+            "MaxPoolingGrad is not yet supported on the depth dimension."));
+
+    PoolParameters params{context, ksize, stride,
                           padding_, FORMAT_NHWC, tensor_in.shape()};
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
@@ -474,7 +563,7 @@ class MaxPoolingGradGradOp : public OpKernel {
     // tensor_out_as_matrix with the corresponding values in
     // top_diff_as_matrix.
     auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
-        int64 start, int64 limit) {
+                     int64 start, int64 limit) {
       const int32 depth = params.depth;
       const int32 in_rows = params.tensor_in_rows;
       const int32 in_cols = params.tensor_in_cols;
@@ -555,20 +644,22 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                 errors::InvalidArgument("Invalid data format"));
-    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
-    OP_REQUIRES(context, ksize_.size() == 4,
-                errors::InvalidArgument("Sliding window ksize field must "
-                                        "specify 4 dimensions"));
-    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
-    OP_REQUIRES(context, stride_.size() == 4,
-                errors::InvalidArgument("Sliding window strides field must "
-                                        "specify 4 dimensions"));
+    if (context->num_inputs() == 3) {
+      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+      OP_REQUIRES(context, ksize_.size() == 4,
+                  errors::InvalidArgument("Sliding window ksize field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+      OP_REQUIRES(context, stride_.size() == 4,
+                  errors::InvalidArgument("Sliding window strides field must "
+                                          "specify 4 dimensions"));
+      const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
+      const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
+      OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+                  errors::Unimplemented(
+                      "Pooling is not yet supported on the batch dimension."));
+    }
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-    const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
-    const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
-    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
-                errors::Unimplemented(
-                    "Pooling is not yet supported on the batch dimension."));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -590,7 +681,33 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, tensor_out.shape(), &output));
 
-    PoolParameters params{context, ksize_, stride_,
+    std::vector<int32> ksize = ksize_;
+    std::vector<int32> stride = stride_;
+    if (context->num_inputs() == 5) {
+      const Tensor& tensor_ksize = context->input(3);
+      auto value_ksize = tensor_ksize.flat<int32>();
+      ksize.resize(tensor_ksize.shape().num_elements());
+      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+      const Tensor& tensor_stride = context->input(4);
+      auto value_stride = tensor_stride.flat<int32>();
+      stride.resize(tensor_stride.shape().num_elements());
+      std::copy_n(&value_stride(0), stride.size(), stride.begin());
+    }
+
+    OP_REQUIRES(context, ksize.size() == 4,
+                errors::InvalidArgument("Sliding window ksize field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, stride.size() == 4,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 4 dimensions"));
+    const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N');
+    const int32 stride_n = GetTensorDim(stride, data_format_, 'N');
+    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+                errors::Unimplemented(
+                    "Pooling is not yet supported on the batch dimension."));
+
+    PoolParameters params{context, ksize, stride,
                           padding_, data_format_, tensor_in.shape()};
 
     functor::MaxPoolGradBackwardNoMask<T>()(
@@ -669,6 +786,84 @@ class MaxPoolingNoMaskOp : public OpKernel {
   TensorFormat data_format_;
 };
 
+template <typename Device, typename T>
+class MaxPoolingNoMaskV2Op : public OpKernel {
+ public:
+  explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string data_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES(
+        context, data_format_ == FORMAT_NHWC,
+        errors::InvalidArgument(
+            "Default MaxPoolingNoMaskOp only supports NHWC on device type ",
+            DeviceTypeString(context->device_type())));
+    if (context->num_inputs() == 1) {
+      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+      OP_REQUIRES(context, ksize_.size() == 4,
+                  errors::InvalidArgument("Sliding window ksize field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+      OP_REQUIRES(context, stride_.size() == 4,
+                  errors::InvalidArgument("Sliding window stride field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+                  errors::Unimplemented(
+                      "Pooling is not yet supported on the batch dimension."));
+    }
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& tensor_in = context->input(0);
+
+    std::vector<int32> ksize = ksize_;
+    std::vector<int32> stride = stride_;
+
+    if (context->num_inputs() != 1) {
+      const Tensor& tensor_ksize = context->input(1);
+      auto value_ksize = tensor_ksize.flat<int32>();
+      ksize.resize(tensor_ksize.shape().num_elements());
+      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+      const Tensor& tensor_stride = context->input(2);
+      auto value_stride = tensor_stride.flat<int32>();
+      stride.resize(tensor_stride.shape().num_elements());
+      std::copy_n(&value_stride(0), stride.size(), stride.begin());
+    }
+    OP_REQUIRES(context, ksize.size() == 4,
+                errors::InvalidArgument("Sliding window ksize field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, stride.size() == 4,
+                errors::InvalidArgument("Sliding window stride field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
+                errors::Unimplemented(
+                    "Pooling is not yet supported on the batch dimension."));
+    PoolParameters params{context, ksize, stride,
+                          padding_, data_format_, tensor_in.shape()};
+    if (!context->status().ok()) {
+      return;
+    }
+
+    TensorShape out_shape({params.tensor_in_batch, params.out_height,
+                           params.out_width, params.depth});
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+    LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
+                                              output);
+  }
+
+ private:
+  std::vector<int32> ksize_;
+  std::vector<int32> stride_;
+  Padding padding_;
+  TensorFormat data_format_;
+};
+
 template <typename Device, typename T>
 struct LaunchMaxPoolingWithArgmax;
 
@@ -878,6 +1073,95 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
   bool use_dnn_;
 };
 
+template <typename T>
+class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
+ public:
+  typedef GPUDevice Device;
+  explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string data_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                errors::InvalidArgument("Invalid data format"));
+    if (context->num_inputs() == 1) {
+      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+      OP_REQUIRES(context, ksize_.size() == 4,
+                  errors::InvalidArgument("Sliding window ksize field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+      OP_REQUIRES(context, stride_.size() == 4,
+                  errors::InvalidArgument("Sliding window stride field must "
+                                          "specify 4 dimensions"));
+      const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
+      const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
+      OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+                  errors::Unimplemented(
+                      "Pooling is not yet supported on the batch dimension."));
+    }
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+    use_dnn_ = CanUseCudnn();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& tensor_in = context->input(0);
+
+    std::vector<int32> ksize = ksize_;
+    std::vector<int32> stride = stride_;
+
+    if (context->num_inputs() != 1) {
+      const Tensor& tensor_ksize = context->input(1);
+      auto value_ksize = tensor_ksize.flat<int32>();
+      ksize.resize(tensor_ksize.shape().num_elements());
+      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+      const Tensor& tensor_stride = context->input(2);
+      auto value_stride = tensor_stride.flat<int32>();
+      stride.resize(tensor_stride.shape().num_elements());
+      std::copy_n(&value_stride(0), stride.size(), stride.begin());
+    }
+    OP_REQUIRES(context, ksize.size() == 4,
+                errors::InvalidArgument("Sliding window ksize field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, stride.size() == 4,
+                errors::InvalidArgument("Sliding window stride field must "
+                                        "specify 4 dimensions"));
+    const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N');
+    const int32 stride_n = GetTensorDim(stride, data_format_, 'N');
+    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+                errors::Unimplemented(
+                    "Pooling is not yet supported on the batch dimension."));
+
+    PoolParameters params{context, ksize, stride,
+                          padding_, data_format_, tensor_in.shape()};
+    if (!context->status().ok()) {
+      return;
+    }
+
+    TensorShape out_shape =
+        ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
+                        params.out_width, params.depth);
+    if (use_dnn_ && data_format_ == FORMAT_NCHW) {
+      DnnPoolingOp<T>::Compute(
+          context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize,
+          stride, padding_, data_format_, tensor_in, out_shape);
+    } else {
+      CHECK(data_format_ == FORMAT_NHWC)
+          << "Non-Cudnn MaxPool only supports NHWC format";
+      Tensor* output = nullptr;
+      OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+      LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
+                                                output);
+    }
+  }
+
+ private:
+  std::vector<int32> ksize_;
+  std::vector<int32> stride_;
+  Padding padding_;
+  TensorFormat data_format_;
+  bool use_dnn_;
+};
+
 template <typename T>
 struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
   static void launch(OpKernelContext* context, const PoolParameters& params,
@@ -969,13 +1253,28 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
       MaxPoolingGradOp<D##Device, T>);                                   \
   REGISTER_KERNEL_BUILDER(                                               \
       Name("MaxPoolGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
-      MaxPoolingGradGradOp<D##Device, T>);
+      MaxPoolingGradGradOp<D##Device, T>);                               \
+  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradV2")                          \
+                              .Device(DEVICE_##D)                        \
+                              .HostMemory("ksize")                       \
+                              .HostMemory("strides")                     \
+                              .TypeConstraint<T>("T"),                   \
+                          MaxPoolingGradOp<D##Device, T>);               \
+  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradV2")                      \
+                              .Device(DEVICE_##D)                        \
+                              .HostMemory("ksize")                       \
+                              .HostMemory("strides")                     \
+                              .TypeConstraint<T>("T"),                   \
+                          MaxPoolingGradGradOp<D##Device, T>);
 
 // Below kernels implemented only for CPU device.
-#define REGISTER_CPU_ONLY_POOL_KERNELS(T)                        \
-  REGISTER_KERNEL_BUILDER(                                       \
-      Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      MaxPoolingOp<CPUDevice, T>);
+#define REGISTER_CPU_ONLY_POOL_KERNELS(T)                          \
+  REGISTER_KERNEL_BUILDER(                                         \
+      Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
+      MaxPoolingOp<CPUDevice, T>);                                 \
+  REGISTER_KERNEL_BUILDER(                                         \
+      Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      MaxPoolingV2Op<CPUDevice, T>);
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
 #undef REGISTER_CPU_ONLY_POOL_KERNELS
 
@@ -1015,9 +1314,22 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
                               .TypeConstraint<T>("T")            \
                               .Label("eigen_tensor"),            \
                           MaxPoolingOp<GPUDevice, T>);           \
+  REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")                      \
+                              .Device(DEVICE_GPU)                \
+                              .HostMemory("ksize")               \
+                              .HostMemory("strides")             \
+                              .TypeConstraint<T>("T")            \
+                              .Label("eigen_tensor"),            \
+                          MaxPoolingV2Op<GPUDevice, T>);         \
   REGISTER_KERNEL_BUILDER(                                       \
       Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
       MaxPoolingNoMaskOp<GPUDevice, T>);                         \
+  REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")                      \
+                              .Device(DEVICE_GPU)                \
+                              .HostMemory("ksize")               \
+                              .HostMemory("strides")             \
+                              .TypeConstraint<T>("T"),           \
+                          MaxPoolingNoMaskV2Op<GPUDevice, T>);   \
   REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")              \
                               .Device(DEVICE_GPU)                \
                               .TypeConstraint<int64>("Targmax")  \
@@ -69,6 +69,8 @@ struct PoolParameters {
 };
 
 // An implementation of MaxPooling (forward).
+// TODO (yongtang): Remove MaxPoolingOp and use MaxPoolingV2Op;
+// QuantizedMaxPoolingOp depends on MaxPoolingOp, so keep it intact for now.
 template <typename Device, typename T>
 class MaxPoolingOp : public OpKernel {
  public:
@@ -254,6 +256,219 @@ class MaxPoolingOp : public OpKernel {
   TensorFormat data_format_;
 };
 
+template <typename Device, typename T>
+class MaxPoolingV2Op : public OpKernel {
+ public:
+  explicit MaxPoolingV2Op(OpKernelConstruction* context) : OpKernel(context) {
+    string data_format;
+    auto status = context->GetAttr("data_format", &data_format);
+    if (status.ok()) {
+      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                  errors::InvalidArgument("Invalid data format"));
+      OP_REQUIRES(
+          context, data_format_ == FORMAT_NHWC,
+          errors::InvalidArgument("Default MaxPoolingOp only supports NHWC."));
+    } else {
+      data_format_ = FORMAT_NHWC;
+    }
+    if (context->num_inputs() == 1) {
+      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+      OP_REQUIRES(context, ksize_.size() == 4,
+                  errors::InvalidArgument("Sliding window ksize field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+      OP_REQUIRES(context, stride_.size() == 4,
+                  errors::InvalidArgument("Sliding window stride field must "
+                                          "specify 4 dimensions"));
+      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+                  errors::Unimplemented(
+                      "Pooling is not yet supported on the batch dimension."));
+    }
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& tensor_in = context->input(0);
+
+    std::vector<int32> ksize = ksize_;
+    std::vector<int32> stride = stride_;
+
+    if (context->num_inputs() != 1) {
+      const Tensor& tensor_ksize = context->input(1);
+      auto value_ksize = tensor_ksize.flat<int32>();
+      ksize.resize(tensor_ksize.shape().num_elements());
+      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+      const Tensor& tensor_stride = context->input(2);
+      auto value_stride = tensor_stride.flat<int32>();
+      stride.resize(tensor_stride.shape().num_elements());
+      std::copy_n(&value_stride(0), stride.size(), stride.begin());
+    }
+
+    OP_REQUIRES(context, ksize.size() == 4,
+                errors::InvalidArgument("Sliding window ksize field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, stride.size() == 4,
+                errors::InvalidArgument("Sliding window stride field must "
+                                        "specify 4 dimensions"));
+    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
+                errors::Unimplemented(
+                    "Pooling is not yet supported on the batch dimension."));
+
+    PoolParameters params{context, ksize, stride,
+                          padding_, FORMAT_NHWC, tensor_in.shape()};
+    if (!context->status().ok()) {
+      return;
+    }
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                0, params.forward_output_shape(), &output));
+
+    if (params.depth_window > 1) {
+      // Validate spec against the current implementation.  A
+      // relaxation of these requirements would be ideal.
+      OP_REQUIRES(context, params.depth % params.depth_window == 0,
+                  errors::Unimplemented(
+                      "Depthwise max pooling requires "
+                      "the depth window to evenly divide the input depth."));
+      OP_REQUIRES(
+          context, params.depth_window == params.depth_stride,
+          errors::Unimplemented("Depthwise max pooling requires "
+                                "the depth window to equal the depth stride."));
+
+      DepthwiseMaxPool(context, output, tensor_in, params);
+    } else {
+      SpatialMaxPool(context, output, tensor_in, params, padding_);
+    }
+  }
+
+ private:
+  // Single-threaded implementation of DepthwiseMaxPool which
+  // does not handle all of the same options as SpatialMaxPool
+  // (strict assumptions on no padding, stride).
+  //
+  // TODO(vrv): implement a more general depthwise-max pool that works
+  // on GPU as well.
+  void DepthwiseMaxPool(OpKernelContext* context, Tensor* output,
+                        const Tensor& tensor_in, const PoolParameters& params) {
+    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        in_by_pool(tensor_in.flat<T>().data(), params.depth_window,
+                   tensor_in.NumElements() / params.depth_window);
+    Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> out_by_pool(
+        output->flat<T>().data(), 1, output->NumElements());
+    out_by_pool = in_by_pool.colwise().maxCoeff();
+  }
+
+  void SpatialMaxPool(OpKernelContext* context, Tensor* output,
+                      const Tensor& tensor_in, const PoolParameters& params,
+                      const Padding& padding) {
+    // On GPU, use Eigen's Spatial Max Pooling.  On CPU, use an
+    // EigenMatrix version that is currently faster than Eigen's
+    // Spatial MaxPooling implementation.
+    //
+    // TODO(vrv): Remove this once we no longer need it.
+    if (std::is_same<Device, GPUDevice>::value) {
+      Eigen::PaddingType pt = BrainPadding2EigenPadding(padding);
+      functor::SpatialMaxPooling<Device, T>()(
+          context->eigen_device<Device>(), output->tensor<T, 4>(),
+          tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
+          params.row_stride, params.col_stride, pt);
+    } else {
+      typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+          ConstEigenMatrixMap;
+      typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+          EigenMatrixMap;
+
+      ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth,
+                                 params.tensor_in_cols * params.tensor_in_rows *
+                                     params.tensor_in_batch);
+      EigenMatrixMap out_mat(
+          output->flat<T>().data(), params.depth,
+          params.out_width * params.out_height * params.tensor_in_batch);
+
+      const DeviceBase::CpuWorkerThreads& worker_threads =
+          *(context->device()->tensorflow_cpu_worker_threads());
+
+      // The following code basically does the following:
+      // 1. Flattens the input and output tensors into two dimensional arrays.
+      //    tensor_in_as_matrix:
+      //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
+      //    output_as_matrix:
+      //      depth by (out_width * out_height * tensor_in_batch)
+      //
+      // 2. Walks through the set of columns in the flattened
+      //    tensor_in_as_matrix, and updates the corresponding column(s) in
+      //    output_as_matrix with the max value.
+      auto shard = [&params, &in_mat, &out_mat](int64 start, int64 limit) {
+
+        const int32 in_rows = params.tensor_in_rows;
+        const int32 in_cols = params.tensor_in_cols;
+        const int32 pad_rows = params.pad_rows;
+        const int32 pad_cols = params.pad_cols;
+        const int32 window_rows = params.window_rows;
+        const int32 window_cols = params.window_cols;
+        const int32 row_stride = params.row_stride;
+        const int32 col_stride = params.col_stride;
+        const int32 out_height = params.out_height;
+        const int32 out_width = params.out_width;
+
+        {
+          // Initializes the output tensor with MIN<T>.
+          const int32 output_image_size = out_height * out_width * params.depth;
+          EigenMatrixMap out_shard(out_mat.data() + start * output_image_size,
+                                   1, (limit - start) * output_image_size);
+          out_shard.setConstant(Eigen::NumTraits<T>::lowest());
+        }
+
+        for (int32 b = start; b < limit; ++b) {
+          const int32 out_offset_batch = b * out_height;
+          for (int32 h = 0; h < in_rows; ++h) {
+            for (int32 w = 0; w < in_cols; ++w) {
+              // (h_start, h_end) * (w_start, w_end) is the range that the input
+              // vector projects to.
+              const int32 hpad = h + pad_rows;
+              const int32 wpad = w + pad_cols;
+              const int32 h_start = (hpad < window_rows)
+                                        ? 0
+                                        : (hpad - window_rows) / row_stride + 1;
+              const int32 h_end = std::min(hpad / row_stride + 1, out_height);
+              const int32 w_start = (wpad < window_cols)
+                                        ? 0
+                                        : (wpad - window_cols) / col_stride + 1;
+              const int32 w_end = std::min(wpad / col_stride + 1, out_width);
+              // compute elementwise max
+              const int32 in_offset = (b * in_rows + h) * in_cols + w;
+              for (int32 ph = h_start; ph < h_end; ++ph) {
+                const int32 out_offset_base =
+                    (out_offset_batch + ph) * out_width;
+                for (int32 pw = w_start; pw < w_end; ++pw) {
+                  const int32 out_offset = out_offset_base + pw;
+                  out_mat.col(out_offset) =
+                      out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
+                }
+              }
+            }
+          }
+        }
+      };
+
+      // TODO(andydavis) Consider sharding across batch x rows x cols.
+      // TODO(andydavis) Consider a higher resolution shard cost model.
+      const int64 shard_cost =
+          params.tensor_in_rows * params.tensor_in_cols * params.depth;
+      Shard(worker_threads.num_threads, worker_threads.workers,
+            params.tensor_in_batch, shard_cost, shard);
+    }
+  }
+
+  std::vector<int32> ksize_;
+  std::vector<int32> stride_;
+  Padding padding_;
+  TensorFormat data_format_;
+};
+
 template <typename Device, typename T>
 void SpatialAvgPool(OpKernelContext* context, Tensor* output,
                     const Tensor& input, const PoolParameters& params,
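The depth_window > 1 branch above is reachable through the V2 op as well: a window laid over channels instead of space selects DepthwiseMaxPool. A small sketch mirroring the depthwise test added later in this commit:

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops

    # Six channels pooled in groups of three: depth_window == depth_stride == 3.
    x = tf.constant(np.arange(24, dtype=np.float32).reshape(1, 2, 2, 6))
    y = gen_nn_ops._max_pool_v2(
        x, ksize=[1, 1, 1, 3], strides=[1, 1, 1, 3], padding="SAME")

    with tf.Session() as sess:
      # Output shape (1, 2, 2, 2); each value is a max over 3 channels.
      print(sess.run(y))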
@@ -1368,6 +1368,34 @@ input: 4-D input to pool over.
 output: The max pooled output tensor.
 )doc");
 
+REGISTER_OP("MaxPoolV2")
+    .Attr("T: realnumbertype = DT_FLOAT")
+    .Attr(GetPaddingAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Input("input: T")
+    .Input("ksize: int32")
+    .Input("strides: int32")
+    .Output("output: T")
+    .SetShapeFn([](InferenceContext* c) {
+      TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Performs max pooling on the input.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+  input tensor.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+    default format "NHWC", the data is stored in the order of:
+        [batch, in_height, in_width, in_channels].
+    Alternatively, the format could be "NCHW", the data storage order of:
+        [batch, in_channels, in_height, in_width].
+input: 4-D input to pool over.
+output: The max pooled output tensor.
+)doc");
+
 REGISTER_OP("MaxPoolGrad")
     .Attr("ksize: list(int) >= 4")
     .Attr("strides: list(int) >= 4")
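The registration above makes MaxPoolV2 a drop-in counterpart of MaxPool whose window arrives as tensors. A quick equivalence check with constant sizes (wrapper names as in the test file):

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops
    from tensorflow.python.ops import nn_ops

    data = np.random.rand(1, 6, 6, 2).astype(np.float32)
    x = tf.constant(data)
    y_v1 = nn_ops.max_pool(
        x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    y_v2 = gen_nn_ops._max_pool_v2(
        x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    with tf.Session() as sess:
      a, b = sess.run([y_v1, y_v2])
      np.testing.assert_allclose(a, b)  # identical maths, different plumbing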
@@ -1399,6 +1427,37 @@ grad: 4-D. Gradients w.r.t. the output of `max_pool`.
 output: Gradients w.r.t. the input to `max_pool`.
 )doc");
 
+REGISTER_OP("MaxPoolGradV2")
+    .Attr(GetPaddingAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Input("orig_input: T")
+    .Input("orig_output: T")
+    .Input("grad: T")
+    .Input("ksize: int32")
+    .Input("strides: int32")
+    .Output("output: T")
+    .Attr("T: realnumbertype = DT_FLOAT")
+    .SetShapeFn([](InferenceContext* c) {
+      return UnchangedShapeWithRank(c, 4);
+    })
+    .Doc(R"doc(
+Computes gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+  input tensor.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+    default format "NHWC", the data is stored in the order of:
+        [batch, in_height, in_width, in_channels].
+    Alternatively, the format could be "NCHW", the data storage order of:
+        [batch, in_channels, in_height, in_width].
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: 4-D. Gradients w.r.t. the output of `max_pool`.
+output: Gradients w.r.t. the input to `max_pool`.
+)doc");
+
 REGISTER_OP("MaxPoolGradGrad")
     .Attr("ksize: list(int) >= 4")
     .Attr("strides: list(int) >= 4")
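MaxPoolGradV2 is what tf.gradients emits when differentiating through MaxPoolV2; the test file imports nn_grad for exactly that registration. A sketch, assuming that gradient wiring:

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops
    import tensorflow.python.ops.nn_grad  # registers pooling gradients

    x = tf.constant(np.random.rand(1, 4, 4, 1).astype(np.float32))
    k = tf.placeholder(tf.int32, shape=[4])
    s = tf.placeholder(tf.int32, shape=[4])
    y = gen_nn_ops._max_pool_v2(x, ksize=k, strides=s, padding="VALID")
    # The emitted MaxPoolGradV2 receives the same ksize/strides tensors as
    # trailing int32 inputs, so the backward pass is dynamic too.
    dx, = tf.gradients(y, x)

    with tf.Session() as sess:
      g = sess.run(dx, {k: [1, 2, 2, 1], s: [1, 2, 2, 1]})
      print(g.shape)  # (1, 4, 4, 1)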
@@ -1436,6 +1495,43 @@ grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
 output: Gradients of gradients w.r.t. the input to `max_pool`.
 )doc");
 
+REGISTER_OP("MaxPoolGradGradV2")
+    .Attr(GetPaddingAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Input("orig_input: T")
+    .Input("orig_output: T")
+    .Input("grad: T")
+    .Input("ksize: int32")
+    .Input("strides: int32")
+    .Output("output: T")
+    .Attr("T: realnumbertype")
+    .SetShapeFn([](InferenceContext* c) {
+      TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5));
+      ShapeHandle unused;
+      // Validate that 'orig_input' is the same shape as 'grad'.
+      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
+      // Validate that 'orig_output' is the same shape as 'output'.
+      TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Computes second-order gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+  input tensor.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+    default format "NHWC", the data is stored in the order of:
+        [batch, in_height, in_width, in_channels].
+    Alternatively, the format could be "NCHW", the data storage order of:
+        [batch, in_channels, in_height, in_width].
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
+output: Gradients of gradients w.r.t. the input to `max_pool`.
+)doc");
+
 REGISTER_OP("MaxPoolWithArgmax")
     .Attr("ksize: list(int) >= 4")
     .Attr("strides: list(int) >= 4")
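For completeness, the raw second-order op can also be driven directly; the Merge() checks in its shape function require orig_input to match grad and orig_output to match the output. A hypothetical direct call (the wrapper name _max_pool_grad_grad_v2 is assumed from the generated-op naming convention; normally tf.gradients wires this up):

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops

    x = tf.constant(np.random.rand(1, 4, 4, 1).astype(np.float32))
    y = gen_nn_ops._max_pool_v2(
        x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    # grad must be shaped like orig_input; the result is shaped like
    # orig_output, per the shape function above.
    gg = gen_nn_ops._max_pool_grad_grad_v2(
        orig_input=x, orig_output=y, grad=tf.ones_like(x),
        ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")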
@@ -29,6 +29,7 @@ from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.framework import ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
 
@@ -76,7 +77,7 @@ def GetShrunkInceptionMaxPoolShapes(shrink=30):
 class PoolingTest(test.TestCase):
 
   def _VerifyOneType(self, pool_func, input_sizes, ksize, strides, padding,
-                     data_format, data_type, expected, use_gpu):
+                     data_format, data_type, expected, use_gpu, v2):
     """Verifies the output values of the pooling function.
 
     Args:
@@ -103,20 +104,35 @@ class PoolingTest(test.TestCase):
         t = test_util.NHWCToNCHW(t)
         ksize = test_util.NHWCToNCHW(ksize)
         strides = test_util.NHWCToNCHW(strides)
-      t = pool_func(
-          t,
-          ksize=ksize,
-          strides=strides,
-          padding=padding,
-          data_format=data_format)
+      v2 = v2 and data_format != "NCHW"
+      ksize_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
+      strides_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
+      if v2:
+        t = pool_func(
+            t,
+            ksize=ksize_placeholder,
+            strides=strides_placeholder,
+            padding=padding,
+            data_format=data_format)
+      else:
+        t = pool_func(
+            t,
+            ksize=ksize,
+            strides=strides,
+            padding=padding,
+            data_format=data_format)
       if data_format == "NCHW":
         t = test_util.NCHWToNHWC(t)
-      actual = t.eval()
+      if v2:
+        actual = t.eval(feed_dict={ksize_placeholder: ksize,
+                                   strides_placeholder: strides})
+      else:
+        actual = t.eval()
+        self.assertShapeEqual(actual, t)
       self.assertAllCloseAccordingToType(expected, actual.flatten())
-      self.assertShapeEqual(actual, t)
 
   def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
-                     data_format, expected, use_gpu):
+                     data_format, expected, use_gpu, v2):
     """Verifies the output values of the pooling function.
 
     Args:
|
||||
use_gpu: Whether we are running on GPU.
|
||||
"""
|
||||
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, dtypes.float32, expected, use_gpu)
|
||||
data_format, dtypes.float32, expected, use_gpu, v2)
|
||||
|
||||
if not use_gpu or test_util.CudaSupportsHalfMatMulAndConv():
|
||||
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, dtypes.float16, expected, use_gpu)
|
||||
data_format, dtypes.float16, expected, use_gpu, v2)
|
||||
|
||||
def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
|
||||
expected, use_gpu):
|
||||
expected, use_gpu, v2=False):
|
||||
"""Verifies the output values of the pooling function.
|
||||
|
||||
Args:
|
||||
@ -154,7 +170,7 @@ class PoolingTest(test.TestCase):
|
||||
for (data_format, use_gpu_2) in GetTestConfigs():
|
||||
if use_gpu_2 == use_gpu:
|
||||
self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, expected, use_gpu)
|
||||
data_format, expected, use_gpu, v2)
|
||||
|
||||
def _testAvgPoolValidPadding(self, use_gpu):
|
||||
expected_output = [7.0, 8.0, 9.0]
|
||||
@@ -325,6 +341,17 @@ class PoolingTest(test.TestCase):
         expected=expected_output,
         use_gpu=use_gpu)
 
+    for v2 in [True, False]:
+      self._VerifyValues(
+          gen_nn_ops._max_pool_v2,
+          input_sizes=[1, 3, 3, 3],
+          ksize=[1, 2, 2, 1],
+          strides=[1, 2, 2, 1],
+          padding="VALID",
+          expected=expected_output,
+          use_gpu=use_gpu,
+          v2=v2)
+
   def _testMaxPoolSamePadding(self, use_gpu):
     expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
     self._VerifyValues(
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
for v2 in [True, False]:
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 2, 3, 3],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testMaxPoolSamePaddingNonSquareWindow(self, use_gpu):
|
||||
# input is:
|
||||
# [1.0, 2.0
|
||||
@ -354,6 +392,17 @@ class PoolingTest(test.TestCase):
|
||||
expected=[2.0, 2.0, 4.0, 4.0],
|
||||
use_gpu=use_gpu)
|
||||
|
||||
for v2 in [True, False]:
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 2, 2, 1],
|
||||
ksize=[1, 1, 2, 1],
|
||||
strides=[1, 1, 1, 1],
|
||||
padding="SAME",
|
||||
expected=[2.0, 2.0, 4.0, 4.0],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testMaxPoolValidPaddingUnevenStride(self, use_gpu):
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
@ -372,6 +421,26 @@ class PoolingTest(test.TestCase):
|
||||
expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0],
|
||||
use_gpu=use_gpu)
|
||||
|
||||
for v2 in [True, False]:
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 1, 2, 1],
|
||||
padding="VALID",
|
||||
expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 1, 1],
|
||||
padding="VALID",
|
||||
expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testMaxPoolSamePaddingPacket4(self, use_gpu):
|
||||
expected_output = [
|
||||
21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0,
|
||||
@ -386,6 +455,17 @@ class PoolingTest(test.TestCase):
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
for v2 in [True, False]:
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 4, 4, 4],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testMaxPoolSamePaddingPacket8(self, use_gpu):
|
||||
expected_output = [
|
||||
145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0,
|
||||
@ -411,6 +491,17 @@ class PoolingTest(test.TestCase):
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
for v2 in [True, False]:
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 8, 8, 8],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def testMaxPooling(self):
|
||||
for use_gpu in True, False:
|
||||
self._testMaxPoolValidPadding(use_gpu)
|
||||
@ -435,6 +526,17 @@ class PoolingTest(test.TestCase):
|
||||
expected=[2.0, 4.0, 6.0, 8.0, 10.0],
|
||||
use_gpu=False)
|
||||
|
||||
for v2 in [True, False]:
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 1, 1, 10],
|
||||
ksize=[1, 1, 1, 2],
|
||||
strides=[1, 1, 1, 2],
|
||||
padding="SAME",
|
||||
expected=[2.0, 4.0, 6.0, 8.0, 10.0],
|
||||
use_gpu=False,
|
||||
v2=v2)
|
||||
|
||||
def testDepthwiseMaxPool2x2DepthWindow3(self):
|
||||
# input is:
|
||||
#
|
||||
@@ -450,6 +552,17 @@ class PoolingTest(test.TestCase):
         expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0],
         use_gpu=False)
 
+    for v2 in [True, False]:
+      self._VerifyValues(
+          gen_nn_ops._max_pool_v2,
+          input_sizes=[1, 2, 2, 6],
+          ksize=[1, 1, 1, 3],
+          strides=[1, 1, 1, 3],
+          padding="SAME",
+          expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0],
+          use_gpu=False,
+          v2=v2)
+
   def testKernelSmallerThanStrideValid(self):
     for use_gpu in [True, False]:
       self._VerifyValues(
|
||||
expected=[9, 12, 30, 33],
|
||||
use_gpu=use_gpu)
|
||||
|
||||
for v2 in [True, False]:
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 7, 7, 1],
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 3, 3, 1],
|
||||
padding="VALID",
|
||||
expected=[9, 12, 30, 33],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
self._VerifyValues(
|
||||
nn_ops.avg_pool,
|
||||
input_sizes=[1, 7, 7, 1],
|
||||
@ -491,6 +615,27 @@ class PoolingTest(test.TestCase):
|
||||
expected=[1, 3, 9, 11],
|
||||
use_gpu=use_gpu)
|
||||
|
||||
for v2 in [True, False]:
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
ksize=[1, 1, 1, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=[1, 3, 7, 9],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[1, 4, 4, 1],
|
||||
ksize=[1, 1, 1, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=[1, 3, 9, 11],
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
def _testDepthwiseMaxPoolInvalidConfig(self,
|
||||
in_size,
|
||||
ksize,
|
||||
@ -812,99 +957,107 @@ class PoolingTest(test.TestCase):
|
||||
self.assertLess(err, err_tolerance)
|
||||
|
||||
def _testMaxPoolGradValidPadding1_1(self, data_format, use_gpu):
|
||||
self._ConstructAndTestGradient(
|
||||
nn_ops.max_pool,
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
output_sizes=[1, 3, 3, 1],
|
||||
window_rows=1,
|
||||
window_cols=1,
|
||||
row_stride=1,
|
||||
col_stride=1,
|
||||
padding="VALID",
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
|
||||
self._ConstructAndTestGradient(
|
||||
pool_func,
|
||||
input_sizes=[1, 3, 3, 1],
|
||||
output_sizes=[1, 3, 3, 1],
|
||||
window_rows=1,
|
||||
window_cols=1,
|
||||
row_stride=1,
|
||||
col_stride=1,
|
||||
padding="VALID",
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
  def _testMaxPoolGradValidPadding2_1_6(self, data_format, use_gpu):
    self._ConstructAndTestGradient(
        nn_ops.max_pool,
        input_sizes=[2, 6, 6, 3],
        output_sizes=[2, 5, 5, 3],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="VALID",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestGradient(
          pool_func,
          input_sizes=[2, 6, 6, 3],
          output_sizes=[2, 5, 5, 3],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradValidPadding2_1_7(self, data_format, use_gpu):
    self._ConstructAndTestGradient(
        nn_ops.max_pool,
        input_sizes=[2, 7, 7, 3],
        output_sizes=[2, 6, 6, 3],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="VALID",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestGradient(
          pool_func,
          input_sizes=[2, 7, 7, 3],
          output_sizes=[2, 6, 6, 3],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradValidPadding2_2(self, data_format, use_gpu):
    self._ConstructAndTestGradient(
        nn_ops.max_pool,
        input_sizes=[2, 2, 2, 3],
        output_sizes=[2, 1, 1, 3],
        window_rows=2,
        window_cols=2,
        row_stride=2,
        col_stride=2,
        padding="VALID",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestGradient(
          pool_func,
          input_sizes=[2, 2, 2, 3],
          output_sizes=[2, 1, 1, 3],
          window_rows=2,
          window_cols=2,
          row_stride=2,
          col_stride=2,
          padding="VALID",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradSamePadding1_1(self, data_format, use_gpu):
    self._ConstructAndTestGradient(
        nn_ops.max_pool,
        input_sizes=[2, 2, 4, 3],
        output_sizes=[2, 2, 4, 3],
        window_rows=1,
        window_cols=1,
        row_stride=1,
        col_stride=1,
        padding="SAME",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestGradient(
          pool_func,
          input_sizes=[2, 2, 4, 3],
          output_sizes=[2, 2, 4, 3],
          window_rows=1,
          window_cols=1,
          row_stride=1,
          col_stride=1,
          padding="SAME",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradSamePadding2_1(self, data_format, use_gpu):
    self._ConstructAndTestGradient(
        nn_ops.max_pool,
        input_sizes=[2, 2, 4, 3],
        output_sizes=[2, 2, 4, 3],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="SAME",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestGradient(
          pool_func,
          input_sizes=[2, 2, 4, 3],
          output_sizes=[2, 2, 4, 3],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="SAME",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradSamePadding2_2(self, data_format, use_gpu):
    self._ConstructAndTestGradient(
        nn_ops.max_pool,
        input_sizes=[2, 2, 4, 3],
        output_sizes=[2, 1, 2, 3],
        window_rows=2,
        window_cols=2,
        row_stride=2,
        col_stride=2,
        padding="SAME",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestGradient(
          pool_func,
          input_sizes=[2, 2, 4, 3],
          output_sizes=[2, 1, 2, 3],
          window_rows=2,
          window_cols=2,
          row_stride=2,
          col_stride=2,
          padding="SAME",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradSamePadding3_1(self, data_format, use_gpu):
    self._ConstructAndTestGradient(
        nn_ops.max_pool,
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestGradient(
          pool_func,
          input_sizes=[1, 7, 7, 1],
          output_sizes=[1, 7, 7, 1],
          window_rows=3,
@@ -927,7 +1080,7 @@ class PoolingTest(test.TestCase):
      self._testMaxPoolGradSamePadding3_1(data_format, use_gpu)

  def _MaxPoolGrad(self, orig_input, orig_output, grad, window_rows,
                   window_cols, row_stride, col_stride, padding):
                   window_cols, row_stride, col_stride, padding, v2):
    """Max Pooling Gradient.

    Args:
@@ -944,26 +1097,29 @@ class PoolingTest(test.TestCase):
    Returns:
      A Tensor.
    """
    return gen_nn_ops._max_pool_grad(orig_input, orig_output, grad,
                                     [1, window_rows, window_cols, 1],
                                     [1, row_stride, col_stride, 1], padding)
    pool_func = gen_nn_ops.max_pool_grad_v2 if v2 else gen_nn_ops._max_pool_grad
    return pool_func(orig_input, orig_output, grad,
                     [1, window_rows, window_cols, 1],
                     [1, row_stride, col_stride, 1], padding)

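When v2 is true, the helper dispatches to gen_nn_ops.max_pool_grad_v2, which differs from _max_pool_grad only in taking the window and stride as tensor inputs instead of attrs. A minimal sketch of that direct call (shapes illustrative):

import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

orig_input = tf.random_normal([1, 4, 4, 1])
orig_output = gen_nn_ops._max_pool_v2(
    orig_input, [1, 2, 2, 1], [1, 2, 2, 1], padding="VALID")
grad = tf.ones_like(orig_output)
# ksize and strides are ordinary tensor inputs here, so they could just
# as well be placeholders or values computed in the graph.
input_grad = gen_nn_ops.max_pool_grad_v2(
    orig_input, orig_output, grad,
    [1, 2, 2, 1], [1, 2, 2, 1], padding="VALID")
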
  def _testMaxPoolGradDirect(self, input_data, output_backprop,
                             expected_input_backprop, input_sizes, output_sizes,
                             window_rows, window_cols, row_stride, col_stride,
                             padding, use_gpu):
                             padding, use_gpu, v2):
    pool_func = gen_nn_ops._max_pool_v2 if v2 else nn_ops.max_pool
    with self.test_session(use_gpu=use_gpu):
      input_tensor = constant_op.constant(input_data, shape=input_sizes)
      output_tensor = nn_ops.max_pool(input_tensor,
                                      [1, window_rows, window_cols, 1],
                                      [1, row_stride, col_stride, 1], padding)
      output_tensor = pool_func(input_tensor,
                                [1, window_rows, window_cols, 1],
                                [1, row_stride, col_stride, 1], padding)
      output_backprop_tensor = constant_op.constant(
          output_backprop, shape=output_sizes)

      input_backprop_tensor = self._MaxPoolGrad(input_tensor, output_tensor,
                                                output_backprop_tensor,
                                                window_rows, window_cols,
                                                row_stride, col_stride, padding)
                                                row_stride, col_stride,
                                                padding, v2)

      actual_input_backprop = input_backprop_tensor.eval()
      self.assertShapeEqual(actual_input_backprop, input_backprop_tensor)
@@ -988,18 +1144,20 @@ class PoolingTest(test.TestCase):
    ]

    for use_gpu in True, False:
      self._testMaxPoolGradDirect(
          input_data,
          output_backprop,
          expected_input_backprop,
          input_sizes=[1, 4, 4, 1],
          output_sizes=[1, 3, 3, 1],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          use_gpu=use_gpu)
      for v2 in [True, False]:
        self._testMaxPoolGradDirect(
            input_data,
            output_backprop,
            expected_input_backprop,
            input_sizes=[1, 4, 4, 1],
            output_sizes=[1, 3, 3, 1],
            window_rows=2,
            window_cols=2,
            row_stride=1,
            col_stride=1,
            padding="VALID",
            use_gpu=use_gpu,
            v2=v2)

  def _testMaxPoolGradDirect1_2(self):
    input_data = [
@@ -1013,18 +1171,20 @@ class PoolingTest(test.TestCase):
    ]

    for use_gpu in True, False:
      self._testMaxPoolGradDirect(
          input_data,
          output_backprop,
          expected_input_backprop,
          input_sizes=[1, 4, 4, 1],
          output_sizes=[1, 3, 3, 1],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          use_gpu=use_gpu)
      for v2 in [True, False]:
        self._testMaxPoolGradDirect(
            input_data,
            output_backprop,
            expected_input_backprop,
            input_sizes=[1, 4, 4, 1],
            output_sizes=[1, 3, 3, 1],
            window_rows=2,
            window_cols=2,
            row_stride=1,
            col_stride=1,
            padding="VALID",
            use_gpu=use_gpu,
            v2=v2)

  def _testMaxPoolGradDirect1_3(self):
    input_data = [
@@ -1069,18 +1229,20 @@ class PoolingTest(test.TestCase):
    ]

    for use_gpu in True, False:
      self._testMaxPoolGradDirect(
          input_data,
          output_backprop,
          expected_input_backprop,
          input_sizes=[1, 4, 4, 1],
          output_sizes=[1, 4, 4, 1],
          window_rows=3,
          window_cols=3,
          row_stride=1,
          col_stride=1,
          padding="SAME",
          use_gpu=use_gpu)
      for v2 in [True, False]:
        self._testMaxPoolGradDirect(
            input_data,
            output_backprop,
            expected_input_backprop,
            input_sizes=[1, 4, 4, 1],
            output_sizes=[1, 4, 4, 1],
            window_rows=3,
            window_cols=3,
            row_stride=1,
            col_stride=1,
            padding="SAME",
            use_gpu=use_gpu,
            v2=v2)

  def _testMaxPoolGradDirectWithNans2_1(self):
    input_data = [float("nan")] * 16
@@ -1090,18 +1252,20 @@ class PoolingTest(test.TestCase):
        11.0, 12.0, 13.0, 0.0, 15.0, 16.0, 17.0, 0.0, 19.0, 20.0, 21.0, 0.0,
        0.0, 0.0, 0.0, 0.0
    ]
    self._testMaxPoolGradDirect(
        input_data,
        output_backprop,
        expected_input_backprop_tf_cpu,
        input_sizes=[1, 4, 4, 1],
        output_sizes=[1, 3, 3, 1],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="VALID",
        use_gpu=False)
    for v2 in [True, False]:
      self._testMaxPoolGradDirect(
          input_data,
          output_backprop,
          expected_input_backprop_tf_cpu,
          input_sizes=[1, 4, 4, 1],
          output_sizes=[1, 3, 3, 1],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          use_gpu=False,
          v2=v2)

    if not test.is_gpu_available():
      return
@@ -1112,18 +1276,20 @@ class PoolingTest(test.TestCase):
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0
    ]
    self._testMaxPoolGradDirect(
        input_data,
        output_backprop,
        expected_input_backprop_cudnn,
        input_sizes=[1, 4, 4, 1],
        output_sizes=[1, 3, 3, 1],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="VALID",
        use_gpu=True)
    for v2 in [True, False]:
      self._testMaxPoolGradDirect(
          input_data,
          output_backprop,
          expected_input_backprop_cudnn,
          input_sizes=[1, 4, 4, 1],
          output_sizes=[1, 3, 3, 1],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          use_gpu=True,
          v2=v2)

  def _testMaxPoolGradDirectWithNans2_2(self):
    input_data = [float("nan")] * 16
@@ -1136,18 +1302,20 @@ class PoolingTest(test.TestCase):
        float("nan"), 12.0, 13.0, 0.0, 15.0, float("nan"), 17.0, 0.0, 19.0,
        20.0, float("nan"), 0.0, 0.0, 0.0, 0.0, 0.0
    ]
    self._testMaxPoolGradDirect(
        input_data,
        output_backprop,
        expected_input_backprop_tf_cpu,
        input_sizes=[1, 4, 4, 1],
        output_sizes=[1, 3, 3, 1],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="VALID",
        use_gpu=False)
    for v2 in [True, False]:
      self._testMaxPoolGradDirect(
          input_data,
          output_backprop,
          expected_input_backprop_tf_cpu,
          input_sizes=[1, 4, 4, 1],
          output_sizes=[1, 3, 3, 1],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          use_gpu=False,
          v2=v2)

    if not test.is_gpu_available():
      return
@@ -1158,18 +1326,20 @@ class PoolingTest(test.TestCase):
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0
    ]
    self._testMaxPoolGradDirect(
        input_data,
        output_backprop,
        expected_input_backprop_cudnn,
        input_sizes=[1, 4, 4, 1],
        output_sizes=[1, 3, 3, 1],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="VALID",
        use_gpu=True)
    for v2 in [True, False]:
      self._testMaxPoolGradDirect(
          input_data,
          output_backprop,
          expected_input_backprop_cudnn,
          input_sizes=[1, 4, 4, 1],
          output_sizes=[1, 3, 3, 1],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          use_gpu=True,
          v2=v2)

  def testMaxPoolGradDirect(self):
    self._testMaxPoolGradDirect1_1()
@@ -1179,108 +1349,116 @@ class PoolingTest(test.TestCase):
    self._testMaxPoolGradDirectWithNans2_2()

  def _testMaxPoolGradGradValidPadding1_1(self, data_format, use_gpu):
    self._ConstructAndTestSecondGradient(
        nn_ops.max_pool,
        input_sizes=[1, 3, 3, 1],
        output_sizes=[1, 3, 3, 1],
        window_rows=1,
        window_cols=1,
        row_stride=1,
        col_stride=1,
        padding="VALID",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestSecondGradient(
          pool_func,
          input_sizes=[1, 3, 3, 1],
          output_sizes=[1, 3, 3, 1],
          window_rows=1,
          window_cols=1,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradGradValidPadding2_1_6(self, data_format, use_gpu):
    self._ConstructAndTestSecondGradient(
        nn_ops.max_pool,
        input_sizes=[2, 6, 6, 3],
        output_sizes=[2, 5, 5, 3],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="VALID",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestSecondGradient(
          pool_func,
          input_sizes=[2, 6, 6, 3],
          output_sizes=[2, 5, 5, 3],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradGradValidPadding2_1_7(self, data_format, use_gpu):
    self._ConstructAndTestSecondGradient(
        nn_ops.max_pool,
        input_sizes=[2, 7, 7, 3],
        output_sizes=[2, 6, 6, 3],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="VALID",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestSecondGradient(
          pool_func,
          input_sizes=[2, 7, 7, 3],
          output_sizes=[2, 6, 6, 3],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="VALID",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradGradValidPadding2_2(self, data_format, use_gpu):
    self._ConstructAndTestSecondGradient(
        nn_ops.max_pool,
        input_sizes=[2, 2, 2, 3],
        output_sizes=[2, 1, 1, 3],
        window_rows=2,
        window_cols=2,
        row_stride=2,
        col_stride=2,
        padding="VALID",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestSecondGradient(
          pool_func,
          input_sizes=[2, 2, 2, 3],
          output_sizes=[2, 1, 1, 3],
          window_rows=2,
          window_cols=2,
          row_stride=2,
          col_stride=2,
          padding="VALID",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradGradSamePadding1_1(self, data_format, use_gpu):
    self._ConstructAndTestSecondGradient(
        nn_ops.max_pool,
        input_sizes=[2, 2, 4, 3],
        output_sizes=[2, 2, 4, 3],
        window_rows=1,
        window_cols=1,
        row_stride=1,
        col_stride=1,
        padding="SAME",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestSecondGradient(
          pool_func,
          input_sizes=[2, 2, 4, 3],
          output_sizes=[2, 2, 4, 3],
          window_rows=1,
          window_cols=1,
          row_stride=1,
          col_stride=1,
          padding="SAME",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradGradSamePadding2_1(self, data_format, use_gpu):
    self._ConstructAndTestSecondGradient(
        nn_ops.max_pool,
        input_sizes=[2, 2, 4, 3],
        output_sizes=[2, 2, 4, 3],
        window_rows=2,
        window_cols=2,
        row_stride=1,
        col_stride=1,
        padding="SAME",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestSecondGradient(
          pool_func,
          input_sizes=[2, 2, 4, 3],
          output_sizes=[2, 2, 4, 3],
          window_rows=2,
          window_cols=2,
          row_stride=1,
          col_stride=1,
          padding="SAME",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradGradSamePadding2_2(self, data_format, use_gpu):
    self._ConstructAndTestSecondGradient(
        nn_ops.max_pool,
        input_sizes=[2, 2, 4, 3],
        output_sizes=[2, 1, 2, 3],
        window_rows=2,
        window_cols=2,
        row_stride=2,
        col_stride=2,
        padding="SAME",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestSecondGradient(
          pool_func,
          input_sizes=[2, 2, 4, 3],
          output_sizes=[2, 1, 2, 3],
          window_rows=2,
          window_cols=2,
          row_stride=2,
          col_stride=2,
          padding="SAME",
          data_format=data_format,
          use_gpu=use_gpu)

  def _testMaxPoolGradGradSamePadding3_1(self, data_format, use_gpu):
    self._ConstructAndTestSecondGradient(
        nn_ops.max_pool,
        input_sizes=[1, 7, 7, 1],
        output_sizes=[1, 7, 7, 1],
        window_rows=3,
        window_cols=3,
        row_stride=1,
        col_stride=1,
        padding="SAME",
        data_format=data_format,
        use_gpu=use_gpu)
    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
      self._ConstructAndTestSecondGradient(
          pool_func,
          input_sizes=[1, 7, 7, 1],
          output_sizes=[1, 7, 7, 1],
          window_rows=3,
          window_cols=3,
          row_stride=1,
          col_stride=1,
          padding="SAME",
          data_format=data_format,
          use_gpu=use_gpu)

  def testMaxPoolGradGrad(self):
    for (data_format, use_gpu) in GetTestConfigs():

@@ -302,6 +302,7 @@ BiasAddV1
Relu6
AvgPool
MaxPool
MaxPoolV2
Softmax
LogSoftmax
FractionalAvgPoolGrad

@@ -541,6 +541,19 @@ def _MaxPoolGrad(op, grad):
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  ksize = op.inputs[1]
  strides = op.inputs[2]
  return gen_nn_ops.max_pool_grad_v2(op.inputs[0],
                                     op.outputs[0],
                                     grad,
                                     ksize,
                                     strides,
                                     padding=op.get_attr("padding"),
                                     data_format=op.get_attr("data_format")), None, None

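With this registration in place, tf.gradients can differentiate through MaxPoolV2 even when the window and stride are only known at run time; the trailing None, None states that the integer ksize and strides inputs receive no gradient. A small sketch (names illustrative):

import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

x = tf.placeholder(tf.float32, [1, 4, 4, 1])
ksize = tf.placeholder(tf.int32, [4])
strides = tf.placeholder(tf.int32, [4])
y = gen_nn_ops._max_pool_v2(x, ksize, strides, padding="VALID")
dx, = tf.gradients(y, [x])  # routed through _MaxPoolGradV2 above
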
@ops.RegisterGradient("MaxPoolWithArgmax")
|
||||
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
|
||||
return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
|
||||
@ -567,6 +580,24 @@ def _MaxPoolGradGrad(op, grad):
|
||||
data_format=op.get_attr("data_format")))
|
||||
|
||||
|
||||
@ops.RegisterGradient("MaxPoolGradV2")
|
||||
def _MaxPoolGradGradV2(op, grad):
|
||||
ksize = op.inputs[3]
|
||||
strides = op.inputs[4]
|
||||
return (array_ops.zeros(
|
||||
shape=array_ops.shape(op.inputs[0]),
|
||||
dtype=op.inputs[0].dtype), array_ops.zeros(
|
||||
shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
|
||||
gen_nn_ops.max_pool_grad_grad_v2(
|
||||
op.inputs[0],
|
||||
op.inputs[1],
|
||||
grad,
|
||||
ksize,
|
||||
strides,
|
||||
padding=op.get_attr("padding"),
|
||||
data_format=op.get_attr("data_format")), None, None)
|
||||
|
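For reference, MaxPoolGradV2 has five inputs, and the tuple returned above lines up with them one to one (a reading aid, not code from the patch):

# op.inputs[0] orig_input  -> zeros (second-order term vanishes)
# op.inputs[1] orig_output -> zeros
# op.inputs[2] grad        -> MaxPoolGradGradV2(...)
# op.inputs[3] ksize       -> None (integer input, not differentiable)
# op.inputs[4] strides     -> None
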

@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  return (array_ops.zeros(