Morphological filtering operations: Dilation and erosion.
Change: 123935296
This commit is contained in:
parent
1c54034106
commit
283782a2dc
@ -1227,6 +1227,7 @@ tf_kernel_libraries(
|
||||
":conv_ops",
|
||||
":depthwise_conv_grad_op",
|
||||
":depthwise_conv_op",
|
||||
":dilation_ops",
|
||||
":ops_util",
|
||||
":pooling_ops",
|
||||
"//tensorflow/core:framework",
|
||||
@ -1347,6 +1348,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "dilation_ops",
|
||||
prefix = "dilation_ops",
|
||||
deps = [
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:nn_ops_op_lib",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "batchtospace_op",
|
||||
prefix = "batchtospace_op",
|
||||
|
491
tensorflow/core/kernels/dilation_ops.cc
Normal file
491
tensorflow/core/kernels/dilation_ops.cc
Normal file
@ -0,0 +1,491 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/nn_ops.cc.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <cfloat>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/kernels/dilation_ops.h"
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_slice.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
void ParseAttributes(OpKernelConstruction* context, std::vector<int32>* strides,
|
||||
std::vector<int32>* rates, Padding* padding) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("strides", strides));
|
||||
OP_REQUIRES(context, strides->size() == 4,
|
||||
errors::InvalidArgument("Sliding window stride field must "
|
||||
"specify 4 dimensions"));
|
||||
OP_REQUIRES(context, (*strides)[0] == 1 && (*strides)[3] == 1,
|
||||
errors::Unimplemented(
|
||||
"Stride is only supported across spatial dimensions."));
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("rates", rates));
|
||||
OP_REQUIRES(context, rates->size() == 4,
|
||||
errors::InvalidArgument("Input stride (atrous rate) field "
|
||||
"must specify 4 dimensions"));
|
||||
OP_REQUIRES(context, (*rates)[0] == 1 && (*rates)[3] == 1,
|
||||
errors::Unimplemented(
|
||||
"Rate is only supported across spatial dimensions."));
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", padding));
|
||||
}
|
||||
|
||||
void ParseSizes(OpKernelContext* context, const std::vector<int32>& strides,
|
||||
const std::vector<int32>& rates, const Padding& padding,
|
||||
int* stride_rows, int* stride_cols, int* rate_rows,
|
||||
int* rate_cols, int* pad_top, int* pad_left, int* out_rows,
|
||||
int* out_cols) {
|
||||
// Input tensor is of the following dimensions:
|
||||
// [ batch, input_rows, input_cols, depth ]
|
||||
const Tensor& input = context->input(0);
|
||||
OP_REQUIRES(context, input.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
input.shape().DebugString()));
|
||||
const int input_rows = input.dim_size(1);
|
||||
const int input_cols = input.dim_size(2);
|
||||
const int depth = input.dim_size(3);
|
||||
|
||||
// For now we take the stride and rate from the second and third dimensions
|
||||
// only (we do not support striding on the batch or depth dimension).
|
||||
*stride_rows = strides[1];
|
||||
*stride_cols = strides[2];
|
||||
*rate_rows = rates[1];
|
||||
*rate_cols = rates[2];
|
||||
|
||||
// Input filter is of the following dimensions:
|
||||
// [ filter_rows, filter_cols, depth ]
|
||||
const Tensor& filter = context->input(1);
|
||||
OP_REQUIRES(context, filter.dims() == 3,
|
||||
errors::InvalidArgument("filter must be 3-dimensional: ",
|
||||
filter.shape().DebugString()));
|
||||
const int filter_rows = filter.dim_size(0);
|
||||
const int filter_cols = filter.dim_size(1);
|
||||
OP_REQUIRES(
|
||||
context, depth == filter.dim_size(2),
|
||||
errors::InvalidArgument("input and filter must have the same depth: ",
|
||||
depth, " vs ", filter.dim_size(2)));
|
||||
|
||||
// Effective filter size, after introducing rate - 1 zeros between each
|
||||
// non-zero filter element.
|
||||
const int filter_rows_eff =
|
||||
filter_rows + (filter_rows - 1) * (*rate_rows - 1);
|
||||
const int filter_cols_eff =
|
||||
filter_cols + (filter_cols - 1) * (*rate_cols - 1);
|
||||
|
||||
int pad_bottom = 0, pad_right = 0;
|
||||
OP_REQUIRES_OK(context,
|
||||
Get2dOutputSizeVerbose(
|
||||
input_rows, input_cols, filter_rows_eff, filter_cols_eff,
|
||||
*stride_rows, *stride_cols, padding, out_rows, out_cols,
|
||||
pad_top, &pad_bottom, pad_left, &pad_right));
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
class DilationOp : public OpKernel {
|
||||
public:
|
||||
explicit DilationOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
ParseAttributes(context, &strides_, &rates_, &padding_);
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input = context->input(0);
|
||||
const Tensor& filter = context->input(1);
|
||||
|
||||
// Determine relevant sizes from input and filters.
|
||||
int stride_rows = 0, stride_cols = 0;
|
||||
int rate_rows = 0, rate_cols = 0;
|
||||
int pad_top = 0, pad_left = 0;
|
||||
int out_rows = 0, out_cols = 0;
|
||||
ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
|
||||
&rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
|
||||
&out_cols);
|
||||
|
||||
// Output tensor is of the following dimensions:
|
||||
// [ batch, out_rows, out_cols, depth ]
|
||||
const int batch = input.dim_size(0);
|
||||
const int depth = input.dim_size(3);
|
||||
const std::vector<int64> out_sizes = {batch, out_rows, out_cols, depth};
|
||||
TensorShape out_shape(out_sizes);
|
||||
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||
|
||||
// If there is nothing to compute, return.
|
||||
if (out_shape.num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
functor::Dilation<Device, T>()(
|
||||
context->eigen_device<Device>(), input.tensor<T, 4>(),
|
||||
filter.tensor<T, 3>(), stride_rows, stride_cols, rate_rows, rate_cols,
|
||||
pad_top, pad_left, output->tensor<T, 4>());
|
||||
}
|
||||
|
||||
std::vector<int32> strides_;
|
||||
std::vector<int32> rates_;
|
||||
Padding padding_;
|
||||
};
|
||||
|
||||
// Partial specialization of Dilation functor for a CPUDevice.
|
||||
namespace functor {
|
||||
template <typename T>
|
||||
struct Dilation<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 3>::ConstTensor filter, int stride_rows,
|
||||
int stride_cols, int rate_rows, int rate_cols, int pad_top,
|
||||
int pad_left, typename TTypes<T, 4>::Tensor output) {
|
||||
const int batch = input.dimension(0);
|
||||
const int input_rows = input.dimension(1);
|
||||
const int input_cols = input.dimension(2);
|
||||
const int depth = input.dimension(3);
|
||||
|
||||
const int filter_rows = filter.dimension(0);
|
||||
const int filter_cols = filter.dimension(1);
|
||||
|
||||
const int output_rows = output.dimension(1);
|
||||
const int output_cols = output.dimension(2);
|
||||
|
||||
// This is a reference implementation, likely to be slow.
|
||||
// TODO(gpapan): Write multi-threaded implementation.
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
for (int h_out = 0; h_out < output_rows; ++h_out) {
|
||||
int h_beg = h_out * stride_rows - pad_top;
|
||||
for (int w_out = 0; w_out < output_cols; ++w_out) {
|
||||
int w_beg = w_out * stride_cols - pad_left;
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
T cur_val = Eigen::NumTraits<T>::lowest();
|
||||
for (int h = 0; h < filter_rows; ++h) {
|
||||
const int h_in = h_beg + h * rate_rows;
|
||||
if (h_in >= 0 && h_in < input_rows) {
|
||||
for (int w = 0; w < filter_cols; ++w) {
|
||||
const int w_in = w_beg + w * rate_cols;
|
||||
if (w_in >= 0 && w_in < input_cols) {
|
||||
const T val = input(b, h_in, w_in, d) + filter(h, w, d);
|
||||
if (val > cur_val) {
|
||||
cur_val = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output(b, h_out, w_out, d) = cur_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace functor
|
||||
|
||||
template <typename Device, typename T>
|
||||
class DilationBackpropInputOp : public OpKernel {
|
||||
public:
|
||||
explicit DilationBackpropInputOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
ParseAttributes(context, &strides_, &rates_, &padding_);
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input = context->input(0);
|
||||
const Tensor& filter = context->input(1);
|
||||
const Tensor& out_backprop = context->input(2);
|
||||
|
||||
// Determine relevant sizes from input and filters.
|
||||
int stride_rows = 0, stride_cols = 0;
|
||||
int rate_rows = 0, rate_cols = 0;
|
||||
int pad_top = 0, pad_left = 0;
|
||||
int out_rows = 0, out_cols = 0;
|
||||
ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
|
||||
&rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
|
||||
&out_cols);
|
||||
|
||||
// Verify that the incoming gradient tensor has the expected size
|
||||
// [ batch, out_rows, out_cols, depth ]
|
||||
const int batch = input.dim_size(0);
|
||||
const int depth = input.dim_size(3);
|
||||
OP_REQUIRES(context, batch == out_backprop.dim_size(0) &&
|
||||
out_rows == out_backprop.dim_size(1) &&
|
||||
out_cols == out_backprop.dim_size(2) &&
|
||||
depth == out_backprop.dim_size(3),
|
||||
errors::InvalidArgument("out_backprop has incompatible size."));
|
||||
|
||||
// The computed in_backprop has the same dimensions as the input:
|
||||
// [ batch, input_rows, input_cols, depth ]
|
||||
Tensor* in_backprop = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, input.shape(), &in_backprop));
|
||||
|
||||
// If there is nothing to compute, return.
|
||||
if (input.shape().num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
functor::DilationBackpropInput<Device, T>()(
|
||||
context->eigen_device<Device>(), input.tensor<T, 4>(),
|
||||
filter.tensor<T, 3>(), out_backprop.tensor<T, 4>(), stride_rows,
|
||||
stride_cols, rate_rows, rate_cols, pad_top, pad_left,
|
||||
in_backprop->tensor<T, 4>());
|
||||
}
|
||||
|
||||
std::vector<int32> strides_;
|
||||
std::vector<int32> rates_;
|
||||
Padding padding_;
|
||||
};
|
||||
|
||||
// Partial specialization of DilationBackpropInput functor for a CPUDevice.
|
||||
namespace functor {
|
||||
template <typename T>
|
||||
struct DilationBackpropInput<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 3>::ConstTensor filter,
|
||||
typename TTypes<T, 4>::ConstTensor out_backprop,
|
||||
int stride_rows, int stride_cols, int rate_rows,
|
||||
int rate_cols, int pad_top, int pad_left,
|
||||
typename TTypes<T, 4>::Tensor in_backprop) {
|
||||
const int batch = input.dimension(0);
|
||||
const int input_rows = input.dimension(1);
|
||||
const int input_cols = input.dimension(2);
|
||||
const int depth = input.dimension(3);
|
||||
|
||||
const int filter_rows = filter.dimension(0);
|
||||
const int filter_cols = filter.dimension(1);
|
||||
|
||||
const int output_rows = out_backprop.dimension(1);
|
||||
const int output_cols = out_backprop.dimension(2);
|
||||
|
||||
// Initialize gradient with all zeros.
|
||||
in_backprop.setZero();
|
||||
|
||||
// This is a reference implementation, likely to be slow.
|
||||
// TODO(gpapan): Write multi-threaded implementation.
|
||||
// In the case of multiple argmax branches, we only back-propagate along the
|
||||
// last branch, i.e., the one with largest value of `h * filter_cols + w`,
|
||||
// similarly to the max-pooling backward routines.
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
for (int h_out = 0; h_out < output_rows; ++h_out) {
|
||||
int h_beg = h_out * stride_rows - pad_top;
|
||||
for (int w_out = 0; w_out < output_cols; ++w_out) {
|
||||
int w_beg = w_out * stride_cols - pad_left;
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
T cur_val = Eigen::NumTraits<T>::lowest();
|
||||
int h_in_max = (h_beg < 0) ? 0 : h_beg;
|
||||
int w_in_max = (w_beg < 0) ? 0 : w_beg;
|
||||
for (int h = 0; h < filter_rows; ++h) {
|
||||
const int h_in = h_beg + h * rate_rows;
|
||||
if (h_in >= 0 && h_in < input_rows) {
|
||||
for (int w = 0; w < filter_cols; ++w) {
|
||||
const int w_in = w_beg + w * rate_cols;
|
||||
if (w_in >= 0 && w_in < input_cols) {
|
||||
const T val = input(b, h_in, w_in, d) + filter(h, w, d);
|
||||
if (val > cur_val) {
|
||||
cur_val = val;
|
||||
h_in_max = h_in;
|
||||
w_in_max = w_in;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
in_backprop(b, h_in_max, w_in_max, d) +=
|
||||
out_backprop(b, h_out, w_out, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace functor
|
||||
|
||||
template <typename Device, typename T>
|
||||
class DilationBackpropFilterOp : public OpKernel {
|
||||
public:
|
||||
explicit DilationBackpropFilterOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
ParseAttributes(context, &strides_, &rates_, &padding_);
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input = context->input(0);
|
||||
const Tensor& filter = context->input(1);
|
||||
const Tensor& out_backprop = context->input(2);
|
||||
|
||||
// Determine relevant sizes from input and filters.
|
||||
int stride_rows = 0, stride_cols = 0;
|
||||
int rate_rows = 0, rate_cols = 0;
|
||||
int pad_top = 0, pad_left = 0;
|
||||
int out_rows = 0, out_cols = 0;
|
||||
ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
|
||||
&rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
|
||||
&out_cols);
|
||||
|
||||
// Verify that the incoming gradient tensor has the expected size
|
||||
// [ batch, out_rows, out_cols, depth ]
|
||||
const int batch = input.dim_size(0);
|
||||
const int depth = input.dim_size(3);
|
||||
OP_REQUIRES(context, batch == out_backprop.dim_size(0) &&
|
||||
out_rows == out_backprop.dim_size(1) &&
|
||||
out_cols == out_backprop.dim_size(2) &&
|
||||
depth == out_backprop.dim_size(3),
|
||||
errors::InvalidArgument("out_backprop has incompatible size."));
|
||||
|
||||
// The computed filter_backprop has the same dimensions as the filter:
|
||||
// [ batch, input_rows, input_cols, depth ]
|
||||
Tensor* filter_backprop = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output(0, filter.shape(), &filter_backprop));
|
||||
|
||||
// If there is nothing to compute, return.
|
||||
if (filter.shape().num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
functor::DilationBackpropFilter<Device, T>()(
|
||||
context->eigen_device<Device>(), input.tensor<T, 4>(),
|
||||
filter.tensor<T, 3>(), out_backprop.tensor<T, 4>(), stride_rows,
|
||||
stride_cols, rate_rows, rate_cols, pad_top, pad_left,
|
||||
filter_backprop->tensor<T, 3>());
|
||||
}
|
||||
|
||||
std::vector<int32> strides_;
|
||||
std::vector<int32> rates_;
|
||||
Padding padding_;
|
||||
};
|
||||
|
||||
// Partial specialization of DilationBackpropFilter functor for a CPUDevice.
|
||||
namespace functor {
|
||||
template <typename T>
|
||||
struct DilationBackpropFilter<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 3>::ConstTensor filter,
|
||||
typename TTypes<T, 4>::ConstTensor out_backprop,
|
||||
int stride_rows, int stride_cols, int rate_rows,
|
||||
int rate_cols, int pad_top, int pad_left,
|
||||
typename TTypes<T, 3>::Tensor filter_backprop) {
|
||||
const int batch = input.dimension(0);
|
||||
const int input_rows = input.dimension(1);
|
||||
const int input_cols = input.dimension(2);
|
||||
const int depth = input.dimension(3);
|
||||
|
||||
const int filter_rows = filter.dimension(0);
|
||||
const int filter_cols = filter.dimension(1);
|
||||
|
||||
const int output_rows = out_backprop.dimension(1);
|
||||
const int output_cols = out_backprop.dimension(2);
|
||||
|
||||
// Initialize gradient with all zeros.
|
||||
filter_backprop.setZero();
|
||||
|
||||
// This is a reference implementation, likely to be slow.
|
||||
// TODO(gpapan): Write multi-threaded implementation.
|
||||
// In the case of multiple argmax branches, we only back-propagate along the
|
||||
// last branch, i.e., the one with largest value of `h * filter_cols + w`,
|
||||
// similarly to the max-pooling backward routines.
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
for (int h_out = 0; h_out < output_rows; ++h_out) {
|
||||
int h_beg = h_out * stride_rows - pad_top;
|
||||
for (int w_out = 0; w_out < output_cols; ++w_out) {
|
||||
int w_beg = w_out * stride_cols - pad_left;
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
T cur_val = Eigen::NumTraits<T>::lowest();
|
||||
int h_max = 0;
|
||||
int w_max = 0;
|
||||
for (int h = 0; h < filter_rows; ++h) {
|
||||
const int h_in = h_beg + h * rate_rows;
|
||||
if (h_in >= 0 && h_in < input_rows) {
|
||||
for (int w = 0; w < filter_cols; ++w) {
|
||||
const int w_in = w_beg + w * rate_cols;
|
||||
if (w_in >= 0 && w_in < input_cols) {
|
||||
const T val = input(b, h_in, w_in, d) + filter(h, w, d);
|
||||
if (val > cur_val) {
|
||||
cur_val = val;
|
||||
h_max = h;
|
||||
w_max = w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
filter_backprop(h_max, w_max, d) +=
|
||||
out_backprop(b, h_out, w_out, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace functor
|
||||
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Dilation2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
DilationOp<CPUDevice, T>); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropInput") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
DilationBackpropInputOp<CPUDevice, T>); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropFilter") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
DilationBackpropFilterOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Dilation2D").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
|
||||
DilationOp<GPUDevice, T>); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropInput") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
DilationBackpropInputOp<GPUDevice, T>); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropFilter") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
DilationBackpropFilterOp<GPUDevice, T>);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // namespace tensorflow
|
66
tensorflow/core/kernels/dilation_ops.h
Normal file
66
tensorflow/core/kernels/dilation_ops.h
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DILATION_OPS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DILATION_OPS_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct Dilation {
|
||||
// We assume that the tensor sizes are correct.
|
||||
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 3>::ConstTensor filter, int stride_rows,
|
||||
int stride_cols, int rate_rows, int rate_cols, int pad_top,
|
||||
int pad_left, typename TTypes<T, 4>::Tensor output);
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct DilationBackpropInput {
|
||||
// We assume that the tensor sizes are correct.
|
||||
// To avoid storing the argmax values during forward computation, we recompute
|
||||
// the argmax during backward computation, which is the reason why we provide
|
||||
// filter as argument to the backward computation routine.
|
||||
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 3>::ConstTensor filter,
|
||||
typename TTypes<T, 4>::ConstTensor out_backprop,
|
||||
int stride_rows, int stride_cols, int rate_rows,
|
||||
int rate_cols, int pad_top, int pad_left,
|
||||
typename TTypes<T, 4>::Tensor in_backprop);
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct DilationBackpropFilter {
|
||||
// We assume that the tensor sizes are correct.
|
||||
// To avoid storing the argmax values during forward computation, we recompute
|
||||
// the argmax during backward computation, which is the reason why we provide
|
||||
// filter as argument to the backward computation routine.
|
||||
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 3>::ConstTensor filter,
|
||||
typename TTypes<T, 4>::ConstTensor out_backprop,
|
||||
int stride_rows, int stride_cols, int rate_rows,
|
||||
int rate_cols, int pad_top, int pad_left,
|
||||
typename TTypes<T, 3>::Tensor filter_backprop);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DILATION_OPS_H_
|
304
tensorflow/core/kernels/dilation_ops_gpu.cu.cc
Normal file
304
tensorflow/core/kernels/dilation_ops_gpu.cu.cc
Normal file
@ -0,0 +1,304 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/nn_ops.cc.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include <cfloat>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/kernels/dilation_ops.h"
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
__global__ void DilationKernel(const int32 nthreads, const T* input_ptr,
|
||||
const T* filter_ptr, int batch, int input_rows,
|
||||
int input_cols, int depth, int filter_rows,
|
||||
int filter_cols, int output_rows,
|
||||
int output_cols, int stride_rows,
|
||||
int stride_cols, int rate_rows, int rate_cols,
|
||||
int pad_top, int pad_left, T* output_ptr) {
|
||||
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
|
||||
// out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b))
|
||||
const int d = out_idx % depth;
|
||||
const int out_idx2 = out_idx / depth;
|
||||
const int w_out = out_idx2 % output_cols;
|
||||
const int out_idx3 = out_idx2 / output_cols;
|
||||
const int h_out = out_idx3 % output_rows;
|
||||
const int b = out_idx3 / output_rows;
|
||||
int h_beg = h_out * stride_rows - pad_top;
|
||||
int w_beg = w_out * stride_cols - pad_left;
|
||||
T cur_val = Eigen::NumTraits<T>::lowest();
|
||||
for (int h = 0; h < filter_rows; ++h) {
|
||||
const int h_in = h_beg + h * rate_rows;
|
||||
if (h_in >= 0 && h_in < input_rows) {
|
||||
for (int w = 0; w < filter_cols; ++w) {
|
||||
const int w_in = w_beg + w * rate_cols;
|
||||
if (w_in >= 0 && w_in < input_cols) {
|
||||
const T val =
|
||||
input_ptr[d +
|
||||
depth *
|
||||
(w_in + input_cols * (h_in + input_rows * b))] +
|
||||
filter_ptr[d + depth * (w + filter_cols * h)];
|
||||
if (val > cur_val) {
|
||||
cur_val = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output_ptr[out_idx] = cur_val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void DilationBackpropInputKernel(
|
||||
const int32 nthreads, const T* input_ptr, const T* filter_ptr,
|
||||
const T* out_backprop_ptr, int batch, int input_rows, int input_cols,
|
||||
int depth, int filter_rows, int filter_cols, int output_rows,
|
||||
int output_cols, int stride_rows, int stride_cols, int rate_rows,
|
||||
int rate_cols, int pad_top, int pad_left, T* in_backprop_ptr) {
|
||||
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
|
||||
// out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b))
|
||||
const int d = out_idx % depth;
|
||||
const int out_idx2 = out_idx / depth;
|
||||
const int w_out = out_idx2 % output_cols;
|
||||
const int out_idx3 = out_idx2 / output_cols;
|
||||
const int h_out = out_idx3 % output_rows;
|
||||
const int b = out_idx3 / output_rows;
|
||||
int h_beg = h_out * stride_rows - pad_top;
|
||||
int w_beg = w_out * stride_cols - pad_left;
|
||||
T cur_val = Eigen::NumTraits<T>::lowest();
|
||||
int h_in_max = (h_beg < 0) ? 0 : h_beg;
|
||||
int w_in_max = (w_beg < 0) ? 0 : w_beg;
|
||||
// In the case of multiple argmax branches, we only back-propagate along the
|
||||
// last branch, i.e., the one with largest value of `h * filter_cols + w`,
|
||||
// similarly to the max-pooling backward routines.
|
||||
for (int h = 0; h < filter_rows; ++h) {
|
||||
const int h_in = h_beg + h * rate_rows;
|
||||
if (h_in >= 0 && h_in < input_rows) {
|
||||
for (int w = 0; w < filter_cols; ++w) {
|
||||
const int w_in = w_beg + w * rate_cols;
|
||||
if (w_in >= 0 && w_in < input_cols) {
|
||||
const T val =
|
||||
input_ptr[d +
|
||||
depth *
|
||||
(w_in + input_cols * (h_in + input_rows * b))] +
|
||||
filter_ptr[d + depth * (w + filter_cols * h)];
|
||||
if (val > cur_val) {
|
||||
cur_val = val;
|
||||
h_in_max = h_in;
|
||||
w_in_max = w_in;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
CudaAtomicAdd(
|
||||
in_backprop_ptr + d +
|
||||
depth * (w_in_max + input_cols * (h_in_max + input_rows * b)),
|
||||
out_backprop_ptr[out_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void DilationBackpropFilterKernel(
|
||||
const int32 nthreads, const T* input_ptr, const T* filter_ptr,
|
||||
const T* out_backprop_ptr, int batch, int input_rows, int input_cols,
|
||||
int depth, int filter_rows, int filter_cols, int output_rows,
|
||||
int output_cols, int stride_rows, int stride_cols, int rate_rows,
|
||||
int rate_cols, int pad_top, int pad_left, T* filter_backprop_ptr) {
|
||||
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
|
||||
// out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b))
|
||||
const int d = out_idx % depth;
|
||||
const int out_idx2 = out_idx / depth;
|
||||
const int w_out = out_idx2 % output_cols;
|
||||
const int out_idx3 = out_idx2 / output_cols;
|
||||
const int h_out = out_idx3 % output_rows;
|
||||
const int b = out_idx3 / output_rows;
|
||||
int h_beg = h_out * stride_rows - pad_top;
|
||||
int w_beg = w_out * stride_cols - pad_left;
|
||||
T cur_val = Eigen::NumTraits<T>::lowest();
|
||||
int h_max = 0;
|
||||
int w_max = 0;
|
||||
// In the case of multiple argmax branches, we only back-propagate along the
|
||||
// last branch, i.e., the one with largest value of `h * filter_cols + w`,
|
||||
// similarly to the max-pooling backward routines.
|
||||
for (int h = 0; h < filter_rows; ++h) {
|
||||
const int h_in = h_beg + h * rate_rows;
|
||||
if (h_in >= 0 && h_in < input_rows) {
|
||||
for (int w = 0; w < filter_cols; ++w) {
|
||||
const int w_in = w_beg + w * rate_cols;
|
||||
if (w_in >= 0 && w_in < input_cols) {
|
||||
const T val =
|
||||
input_ptr[d +
|
||||
depth *
|
||||
(w_in + input_cols * (h_in + input_rows * b))] +
|
||||
filter_ptr[d + depth * (w + filter_cols * h)];
|
||||
if (val > cur_val) {
|
||||
cur_val = val;
|
||||
h_max = h;
|
||||
w_max = w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
CudaAtomicAdd(
|
||||
filter_backprop_ptr + d + depth * (w_max + filter_cols * h_max),
|
||||
out_backprop_ptr[out_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename T>
|
||||
struct Dilation<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 3>::ConstTensor filter, int stride_rows,
|
||||
int stride_cols, int rate_rows, int rate_cols, int pad_top,
|
||||
int pad_left, typename TTypes<T, 4>::Tensor output) {
|
||||
const int batch = input.dimension(0);
|
||||
const int input_rows = input.dimension(1);
|
||||
const int input_cols = input.dimension(2);
|
||||
const int depth = input.dimension(3);
|
||||
|
||||
const int filter_rows = filter.dimension(0);
|
||||
const int filter_cols = filter.dimension(1);
|
||||
|
||||
const int output_rows = output.dimension(1);
|
||||
const int output_cols = output.dimension(2);
|
||||
|
||||
const int total_count = batch * output_rows * output_cols * depth;
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
|
||||
|
||||
DilationKernel<<<config.block_count, config.thread_per_block, 0,
|
||||
d.stream()>>>(
|
||||
config.virtual_thread_count, input.data(), filter.data(), batch,
|
||||
input_rows, input_cols, depth, filter_rows, filter_cols, output_rows,
|
||||
output_cols, stride_rows, stride_cols, rate_rows, rate_cols, pad_top,
|
||||
pad_left, output.data());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DilationBackpropInput<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 3>::ConstTensor filter,
|
||||
typename TTypes<T, 4>::ConstTensor out_backprop,
|
||||
int stride_rows, int stride_cols, int rate_rows,
|
||||
int rate_cols, int pad_top, int pad_left,
|
||||
typename TTypes<T, 4>::Tensor in_backprop) {
|
||||
const int batch = input.dimension(0);
|
||||
const int input_rows = input.dimension(1);
|
||||
const int input_cols = input.dimension(2);
|
||||
const int depth = input.dimension(3);
|
||||
|
||||
const int filter_rows = filter.dimension(0);
|
||||
const int filter_cols = filter.dimension(1);
|
||||
|
||||
const int output_rows = out_backprop.dimension(1);
|
||||
const int output_cols = out_backprop.dimension(2);
|
||||
|
||||
int total_count;
|
||||
CudaLaunchConfig config;
|
||||
|
||||
// Initialize in_backprop with all zeros.
|
||||
total_count = batch * input_rows * input_cols * depth;
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
total_count, in_backprop.data());
|
||||
|
||||
// Accumulate.
|
||||
total_count = batch * output_rows * output_cols * depth;
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
DilationBackpropInputKernel<<<config.block_count, config.thread_per_block,
|
||||
0, d.stream()>>>(
|
||||
config.virtual_thread_count, input.data(), filter.data(),
|
||||
out_backprop.data(), batch, input_rows, input_cols, depth, filter_rows,
|
||||
filter_cols, output_rows, output_cols, stride_rows, stride_cols,
|
||||
rate_rows, rate_cols, pad_top, pad_left, in_backprop.data());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DilationBackpropFilter<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<T, 3>::ConstTensor filter,
|
||||
typename TTypes<T, 4>::ConstTensor out_backprop,
|
||||
int stride_rows, int stride_cols, int rate_rows,
|
||||
int rate_cols, int pad_top, int pad_left,
|
||||
typename TTypes<T, 3>::Tensor filter_backprop) {
|
||||
const int batch = input.dimension(0);
|
||||
const int input_rows = input.dimension(1);
|
||||
const int input_cols = input.dimension(2);
|
||||
const int depth = input.dimension(3);
|
||||
|
||||
const int filter_rows = filter.dimension(0);
|
||||
const int filter_cols = filter.dimension(1);
|
||||
|
||||
const int output_rows = out_backprop.dimension(1);
|
||||
const int output_cols = out_backprop.dimension(2);
|
||||
|
||||
int total_count;
|
||||
CudaLaunchConfig config;
|
||||
|
||||
// Initialize filter_backprop with all zeros.
|
||||
total_count = filter_rows * filter_cols * depth;
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
total_count, filter_backprop.data());
|
||||
|
||||
// Accumulate.
|
||||
total_count = batch * output_rows * output_cols * depth;
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
DilationBackpropFilterKernel<<<config.block_count, config.thread_per_block,
|
||||
0, d.stream()>>>(
|
||||
config.virtual_thread_count, input.data(), filter.data(),
|
||||
out_backprop.data(), batch, input_rows, input_cols, depth, filter_rows,
|
||||
filter_cols, output_rows, output_cols, stride_rows, stride_cols,
|
||||
rate_rows, rate_cols, pad_top, pad_left, filter_backprop.data());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#define DEFINE_GPU_SPECS(T) \
|
||||
template struct functor::Dilation<GPUDevice, T>; \
|
||||
template struct functor::DilationBackpropInput<GPUDevice, T>; \
|
||||
template struct functor::DilationBackpropFilter<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
|
||||
|
||||
#undef DEFINE_GPU_SPECS
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
@ -740,6 +740,99 @@ output: Gradients w.r.t. the input of `max_pool`.
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
REGISTER_OP("Dilation2D")
|
||||
.Input("input: T")
|
||||
.Input("filter: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: realnumbertype")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
.Attr("rates: list(int) >= 4")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Doc(R"doc(
|
||||
Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors.
|
||||
|
||||
The `input` tensor has shape `[batch, in_height, in_width, depth]` and the
|
||||
`filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each
|
||||
input channel is processed independently of the others with its own structuring
|
||||
function. The `output` tensor has shape
|
||||
`[batch, out_height, out_width, depth]`. The spatial dimensions of the output
|
||||
tensor depend on the `padding` algorithm. We currently only support the default
|
||||
"NHWC" `data_format`.
|
||||
|
||||
In detail, the grayscale morphological 2-D dilation is the max-sum correlation
|
||||
(for consistency with `conv2d`, we use unmirrored filters):
|
||||
|
||||
output[b, y, x, c] =
|
||||
max_{dy, dx} input[b,
|
||||
strides[1] * y + rates[1] * dy,
|
||||
strides[2] * x + rates[2] * dx,
|
||||
c] +
|
||||
filter[dy, dx, c]
|
||||
|
||||
Max-pooling is a special case when the filter has size equal to the pooling
|
||||
kernel size and contains all zeros.
|
||||
|
||||
Duality: The dilation of `input` by the `filter` is equal to the negation of
|
||||
the erosion of `-input` by the reflected `filter`.
|
||||
|
||||
input: 4-D with shape `[batch, in_height, in_width, depth]`.
|
||||
filter: 3-D with shape `[filter_height, filter_width, depth]`.
|
||||
strides: The stride of the sliding window for each dimension of the input
|
||||
tensor. Must be: `[1, stride_height, stride_width, 1]`.
|
||||
rates: The input stride for atrous morphological dilation. Must be:
|
||||
`[1, rate_height, rate_width, 1]`.
|
||||
padding: The type of padding algorithm to use.
|
||||
output: 4-D with shape `[batch, out_height, out_width, depth]`.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("Dilation2DBackpropInput")
|
||||
.Input("input: T")
|
||||
.Input("filter: T")
|
||||
.Input("out_backprop: T")
|
||||
.Output("in_backprop: T")
|
||||
.Attr("T: realnumbertype")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
.Attr("rates: list(int) >= 4")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Doc(R"doc(
|
||||
Computes the gradient of morphological 2-D dilation with respect to the input.
|
||||
|
||||
input: 4-D with shape `[batch, in_height, in_width, depth]`.
|
||||
filter: 3-D with shape `[filter_height, filter_width, depth]`.
|
||||
out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`.
|
||||
in_backprop: 4-D with shape `[batch, in_height, in_width, depth]`.
|
||||
strides: 1-D of length 4. The stride of the sliding window for each dimension of
|
||||
the input tensor. Must be: `[1, stride_height, stride_width, 1]`.
|
||||
rates: 1-D of length 4. The input stride for atrous morphological dilation.
|
||||
Must be: `[1, rate_height, rate_width, 1]`.
|
||||
padding: The type of padding algorithm to use.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("Dilation2DBackpropFilter")
|
||||
.Input("input: T")
|
||||
.Input("filter: T")
|
||||
.Input("out_backprop: T")
|
||||
.Output("filter_backprop: T")
|
||||
.Attr("T: realnumbertype")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
.Attr("rates: list(int) >= 4")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Doc(R"doc(
|
||||
Computes the gradient of morphological 2-D dilation with respect to the filter.
|
||||
|
||||
input: 4-D with shape `[batch, in_height, in_width, depth]`.
|
||||
filter: 3-D with shape `[filter_height, filter_width, depth]`.
|
||||
out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`.
|
||||
filter_backprop: 3-D with shape `[filter_height, filter_width, depth]`.
|
||||
strides: 1-D of length 4. The stride of the sliding window for each dimension of
|
||||
the input tensor. Must be: `[1, stride_height, stride_width, 1]`.
|
||||
rates: 1-D of length 4. The input stride for atrous morphological dilation.
|
||||
Must be: `[1, rate_height, rate_width, 1]`.
|
||||
padding: The type of padding algorithm to use.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
REGISTER_OP("Relu")
|
||||
.Input("features: T")
|
||||
.Output("activations: T")
|
||||
|
541
tensorflow/python/kernel_tests/morphological_ops_test.py
Normal file
541
tensorflow/python/kernel_tests/morphological_ops_test.py
Normal file
@ -0,0 +1,541 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Functional tests for morphological filtering operations."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class DilationTest(tf.test.TestCase):
|
||||
|
||||
def _VerifyValues(self, image, kernel, strides, rates, padding, out, use_gpu):
|
||||
"""Verifies the output values of the dilation function.
|
||||
|
||||
Args:
|
||||
image: Input tensor with shape: [batch, in_height, in_width, channels].
|
||||
kernel: Filter tensor with shape: [filter_height, filter_width, channels].
|
||||
strides: Output strides, specified as [stride_height, stride_width].
|
||||
rates: Atrous rates, specified as [rate_height, rate_width].
|
||||
padding: Padding type.
|
||||
out: Expected output.
|
||||
use_gpu: Whether we are running on GPU.
|
||||
"""
|
||||
strides = [1] + strides + [1]
|
||||
rates = [1] + rates + [1]
|
||||
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
out_tensor = tf.nn.dilation2d(
|
||||
tf.constant(image),
|
||||
tf.constant(kernel),
|
||||
strides=strides,
|
||||
rates=rates,
|
||||
padding=padding,
|
||||
name="dilation2d")
|
||||
self.assertAllClose(out, out_tensor.eval())
|
||||
|
||||
def _testDilationValidPadding(self, use_gpu):
|
||||
# [1, 2, 2, 1]
|
||||
image = [[[[.1], [.2]], [[.3], [.4]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.0]]]
|
||||
# [1, 1, 1, 1]
|
||||
out = [[[[.5]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationSamePadding(self, use_gpu):
|
||||
# [1, 2, 2, 1]
|
||||
image = [[[[.1], [.2]], [[.3], [.4]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.0]]]
|
||||
# [1, 2, 2, 1]
|
||||
out = [[[[.5], [.6]], [[.7], [.8]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationSamePaddingDepth(self, use_gpu):
|
||||
# [1, 2, 2, 3]
|
||||
image = [[[[.1, .2, .0], [.2, .3, .1]], [[.3, .4, .2], [.4, .5, .3]]]]
|
||||
# [2, 2, 3]
|
||||
kernel = [[[.4, .5, .3], [.3, .4, .2]], [[.1, .2, .0], [.0, .1, -.1]]]
|
||||
# [1, 2, 2, 3]
|
||||
out = [[[[.5, .7, .3], [.6, .8, .4]], [[.7, .9, .5], [.8, 1., .6]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationSamePaddingBatch(self, use_gpu):
|
||||
# [2, 2, 2, 1]
|
||||
image = [[[[.1], [.2]], [[.3], [.4]]], [[[.2], [.3]], [[.4], [.5]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.0]]]
|
||||
# [2, 2, 2, 1]
|
||||
out = [[[[.5], [.6]], [[.7], [.8]]], [[[.6], [.7]], [[.8], [.9]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationValidPaddingNonSquareWindow(self, use_gpu):
|
||||
# [1, 2, 2, 1]
|
||||
image = [[[[.1], [.2]], [[.3], [.4]]]]
|
||||
# [1, 2, 1]
|
||||
kernel = [[[.4], [.3]]]
|
||||
# [1, 2, 1, 1]
|
||||
out = [[[[.5]], [[.7]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationSamePaddingRate(self, use_gpu):
|
||||
# [1, 3, 3, 1]
|
||||
image = [[[[.1], [.2], [.3]], [[.4], [.5], [.6]], [[.7], [.8], [.9]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.2]]]
|
||||
# Because rate = 2, the effective kernel is [3, 3, 1]:
|
||||
# kernel_eff = [[[.4], [.0], [.3]],
|
||||
# [[.0], [.0], [.0]],
|
||||
# [[.1], [.0], [.2]]]
|
||||
# [1, 3, 3, 1]
|
||||
out = [[[[.7], [.8], [.6]], [[1.0], [1.1], [.9]], [[.8], [.9], [.9]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[2, 2],
|
||||
padding="SAME",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationValidPaddingUnevenStride(self, use_gpu):
|
||||
# [1, 3, 3, 1]
|
||||
image = [[[[.1], [.2], [.3], [.4]], [[.5], [.6], [.7], [.8]],
|
||||
[[.9], [1.0], [1.1], [1.2]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.2]]]
|
||||
# [1, 2, 2, 1]
|
||||
out = [[[[.8], [1.0]], [[1.2], [1.4]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 2],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def testDilation(self):
|
||||
for use_gpu in True, False:
|
||||
self._testDilationValidPadding(use_gpu)
|
||||
self._testDilationSamePadding(use_gpu)
|
||||
self._testDilationSamePaddingDepth(use_gpu)
|
||||
self._testDilationSamePaddingBatch(use_gpu)
|
||||
self._testDilationValidPaddingNonSquareWindow(use_gpu)
|
||||
self._testDilationSamePaddingRate(use_gpu)
|
||||
self._testDilationValidPaddingUnevenStride(use_gpu)
|
||||
|
||||
def _ConstructAndTestGradient(self, image_shape, kernel_shape, strides, rates,
|
||||
padding, use_gpu):
|
||||
"""Verifies the gradients of the dilation function.
|
||||
|
||||
Args:
|
||||
image_shape: Input shape, [batch, in_height, in_width, channels].
|
||||
kernel_shape: Filter shape, [filter_height, filter_width, channels].
|
||||
strides: Output strides, specified as [stride_height, stride_width].
|
||||
rates: Atrous rates, specified as [rate_height, rate_width].
|
||||
padding: Padding type.
|
||||
use_gpu: Whether we are running on GPU.
|
||||
"""
|
||||
assert image_shape[3] == kernel_shape[2]
|
||||
|
||||
np.random.seed(1) # Make it reproducible.
|
||||
image = np.random.random_sample(image_shape).astype(np.float32)
|
||||
kernel = np.random.random_sample(kernel_shape).astype(np.float32)
|
||||
image_init = np.random.random_sample(image_shape).astype(np.float32)
|
||||
kernel_init = np.random.random_sample(kernel_shape).astype(np.float32)
|
||||
|
||||
strides = [1] + strides + [1]
|
||||
rates = [1] + rates + [1]
|
||||
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
image_tensor = tf.constant(image, shape=image_shape, name="input")
|
||||
kernel_tensor = tf.constant(kernel, shape=kernel_shape, name="filter")
|
||||
out_tensor = tf.nn.dilation2d(image_tensor,
|
||||
kernel_tensor,
|
||||
strides=strides,
|
||||
rates=rates,
|
||||
padding=padding,
|
||||
name="dilation2d")
|
||||
out_shape = out_tensor.eval().shape
|
||||
|
||||
# Small delta is necessary for argmax to remain the same.
|
||||
err = tf.test.compute_gradient_error([image_tensor, kernel_tensor],
|
||||
[image_shape, kernel_shape],
|
||||
out_tensor,
|
||||
out_shape, [image_init, kernel_init],
|
||||
delta=1e-3)
|
||||
|
||||
print("Dilation gradient error = %f" % err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def _testDilationGradValidPadding_1x1x1(self, use_gpu):
|
||||
self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1],
|
||||
kernel_shape=[1, 1, 1],
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationGradSamePadding_1x1x1(self, use_gpu):
|
||||
self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1],
|
||||
kernel_shape=[1, 1, 1],
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationGradSamePadding_1x1x2(self, use_gpu):
|
||||
self._ConstructAndTestGradient(image_shape=[1, 3, 3, 2],
|
||||
kernel_shape=[1, 1, 2],
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationGradValidPadding_2x2x1(self, use_gpu):
|
||||
self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1],
|
||||
kernel_shape=[2, 2, 1],
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationGradSamePadding_2x2x1(self, use_gpu):
|
||||
self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1],
|
||||
kernel_shape=[2, 2, 1],
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationGradSamePaddingBatch_2x2x1(self, use_gpu):
|
||||
self._ConstructAndTestGradient(image_shape=[4, 3, 3, 1],
|
||||
kernel_shape=[2, 2, 1],
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testDilationGradSamePadding_2x2x4(self, use_gpu):
|
||||
self._ConstructAndTestGradient(image_shape=[1, 3, 3, 4],
|
||||
kernel_shape=[2, 2, 4],
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def testDilationGrad(self):
|
||||
for use_gpu in True, False:
|
||||
self._testDilationGradValidPadding_1x1x1(use_gpu)
|
||||
self._testDilationGradSamePadding_1x1x1(use_gpu)
|
||||
self._testDilationGradSamePadding_1x1x2(use_gpu)
|
||||
self._testDilationGradValidPadding_2x2x1(use_gpu)
|
||||
self._testDilationGradSamePadding_2x2x1(use_gpu)
|
||||
self._testDilationGradSamePaddingBatch_2x2x1(use_gpu)
|
||||
self._testDilationGradSamePadding_2x2x4(use_gpu)
|
||||
|
||||
|
||||
class ErosionTest(tf.test.TestCase):
|
||||
|
||||
def _VerifyValues(self, image, kernel, strides, rates, padding, out, use_gpu):
|
||||
"""Verifies the output values of the erosion function.
|
||||
|
||||
Args:
|
||||
image: Input tensor with shape: [batch, in_height, in_width, channels].
|
||||
kernel: Filter tensor with shape: [filter_height, filter_width, channels].
|
||||
strides: Output strides, specified as [stride_height, stride_width].
|
||||
rates: Atrous rates, specified as [rate_height, rate_width].
|
||||
padding: Padding type.
|
||||
out: Expected output.
|
||||
use_gpu: Whether we are running on GPU.
|
||||
"""
|
||||
strides = [1] + strides + [1]
|
||||
rates = [1] + rates + [1]
|
||||
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
out_tensor = tf.nn.erosion2d(
|
||||
tf.constant(image),
|
||||
tf.constant(kernel),
|
||||
strides=strides,
|
||||
rates=rates,
|
||||
padding=padding,
|
||||
name="erosion2d")
|
||||
self.assertAllClose(out, out_tensor.eval())
|
||||
|
||||
def _testErosionValidPadding(self, use_gpu):
|
||||
# [1, 2, 2, 1]
|
||||
image = [[[[.1], [.2]], [[.3], [.4]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.0]]]
|
||||
# [1, 1, 1, 1]
|
||||
out = [[[[.0]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testErosionSamePadding(self, use_gpu):
|
||||
# [1, 2, 2, 1]
|
||||
image = [[[[.1], [.2]], [[.3], [.4]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.0]]]
|
||||
# [1, 2, 2, 1]
|
||||
out = [[[[.0], [.1]], [[.3], [.4]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testErosionSamePaddingDepth(self, use_gpu):
|
||||
# [1, 2, 2, 3]
|
||||
image = [[[[.1, .2, .0], [.2, .3, .1]], [[.3, .4, .2], [.4, .5, .3]]]]
|
||||
# [2, 2, 3]
|
||||
kernel = [[[.4, .5, .3], [.3, .4, .2]], [[.1, .2, .0], [.0, .1, -.1]]]
|
||||
# [1, 2, 2, 3]
|
||||
out = [[[[.0, .0, .0], [.1, .1, .1]], [[.3, .3, .3], [.4, .4, .4]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testErosionSamePaddingBatch(self, use_gpu):
|
||||
# [2, 2, 2, 1]
|
||||
image = [[[[.1], [.2]], [[.3], [.4]]], [[[.2], [.3]], [[.4], [.5]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.0]]]
|
||||
# [2, 2, 2, 1]
|
||||
out = [[[[.0], [.1]], [[.3], [.4]]], [[[.1], [.2]], [[.4], [.5]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="SAME",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testErosionValidPaddingNonSquareWindow(self, use_gpu):
|
||||
# [1, 2, 2, 1]
|
||||
image = [[[[.1], [.2]], [[.3], [.4]]]]
|
||||
# [1, 2, 1]
|
||||
kernel = [[[.4], [.3]]]
|
||||
# [1, 2, 1, 1]
|
||||
out = [[[[-.2]], [[.0]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testErosionSamePaddingRate(self, use_gpu):
|
||||
# [1, 3, 3, 1]
|
||||
image = [[[[.1], [.2], [.3]], [[.4], [.5], [.6]], [[.7], [.8], [.9]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.2]]]
|
||||
# Because rate = 2, the effective kernel is [3, 3, 1]:
|
||||
# kernel_eff = [[[.4], [.0], [.3]],
|
||||
# [[.0], [.0], [.0]],
|
||||
# [[.1], [.0], [.2]]]
|
||||
# [1, 3, 3, 1]
|
||||
out = [[[[.1], [.1], [.2]], [[0.1], [-.1], [.0]], [[.4], [.2], [.3]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 1],
|
||||
rates=[2, 2],
|
||||
padding="SAME",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def _testErosionValidPaddingUnevenStride(self, use_gpu):
|
||||
# [1, 3, 3, 1]
|
||||
image = [[[[.1], [.2], [.3], [.4]], [[.5], [.6], [.7], [.8]],
|
||||
[[.9], [1.0], [1.1], [1.2]]]]
|
||||
# [2, 2, 1]
|
||||
kernel = [[[.4], [.3]], [[.1], [.2]]]
|
||||
# [1, 2, 2, 1]
|
||||
out = [[[[-.1], [.1]], [[.3], [.5]]]]
|
||||
self._VerifyValues(image,
|
||||
kernel,
|
||||
strides=[1, 2],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
out=out,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def testErosion(self):
|
||||
for use_gpu in True, False:
|
||||
self._testErosionValidPadding(use_gpu)
|
||||
self._testErosionSamePadding(use_gpu)
|
||||
self._testErosionSamePaddingDepth(use_gpu)
|
||||
self._testErosionSamePaddingBatch(use_gpu)
|
||||
self._testErosionValidPaddingNonSquareWindow(use_gpu)
|
||||
self._testErosionSamePaddingRate(use_gpu)
|
||||
self._testErosionValidPaddingUnevenStride(use_gpu)
|
||||
|
||||
  def _ConstructAndTestGradient(self, image_shape, kernel_shape, strides, rates,
                                padding, use_gpu):
    """Verifies the gradients of the erosion function.

    Args:
      image_shape: Input shape, [batch, in_height, in_width, channels].
      kernel_shape: Filter shape, [filter_height, filter_width, channels].
      strides: Output strides, specified as [stride_height, stride_width].
      rates: Atrous rates, specified as [rate_height, rate_width].
      padding: Padding type.
      use_gpu: Whether we are running on GPU.
    """
    assert image_shape[3] == kernel_shape[2]

    np.random.seed(1)  # Make it reproducible.
    image = np.random.random_sample(image_shape).astype(np.float32)
    kernel = np.random.random_sample(kernel_shape).astype(np.float32)
    image_init = np.random.random_sample(image_shape).astype(np.float32)
    kernel_init = np.random.random_sample(kernel_shape).astype(np.float32)

    strides = [1] + strides + [1]
    rates = [1] + rates + [1]

    with self.test_session(use_gpu=use_gpu):
      image_tensor = tf.constant(image, shape=image_shape, name="input")
      kernel_tensor = tf.constant(kernel, shape=kernel_shape, name="filter")
      out_tensor = tf.nn.erosion2d(image_tensor,
                                   kernel_tensor,
                                   strides=strides,
                                   rates=rates,
                                   padding=padding,
                                   name="erosion2d")
      out_shape = out_tensor.eval().shape

      # Small delta is necessary for argmax to remain the same.
      err = tf.test.compute_gradient_error([image_tensor, kernel_tensor],
                                           [image_shape, kernel_shape],
                                           out_tensor,
                                           out_shape, [image_init, kernel_init],
                                           delta=1e-3)

      print("Erosion gradient error = %f" % err)
      self.assertLess(err, 1e-4)

  def _testErosionGradValidPadding_1x1x1(self, use_gpu):
    self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1],
                                   kernel_shape=[1, 1, 1],
                                   strides=[1, 1],
                                   rates=[1, 1],
                                   padding="VALID",
                                   use_gpu=use_gpu)

  def _testErosionGradSamePadding_1x1x1(self, use_gpu):
    self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1],
                                   kernel_shape=[1, 1, 1],
                                   strides=[1, 1],
                                   rates=[1, 1],
                                   padding="SAME",
                                   use_gpu=use_gpu)

  def _testErosionGradSamePadding_1x1x2(self, use_gpu):
    self._ConstructAndTestGradient(image_shape=[1, 3, 3, 2],
                                   kernel_shape=[1, 1, 2],
                                   strides=[1, 1],
                                   rates=[1, 1],
                                   padding="SAME",
                                   use_gpu=use_gpu)

  def _testErosionGradValidPadding_2x2x1(self, use_gpu):
    self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1],
                                   kernel_shape=[2, 2, 1],
                                   strides=[1, 1],
                                   rates=[1, 1],
                                   padding="VALID",
                                   use_gpu=use_gpu)

  def _testErosionGradSamePadding_2x2x1(self, use_gpu):
    self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1],
                                   kernel_shape=[2, 2, 1],
                                   strides=[1, 1],
                                   rates=[1, 1],
                                   padding="SAME",
                                   use_gpu=use_gpu)

  def _testErosionGradSamePaddingBatch_2x2x1(self, use_gpu):
    self._ConstructAndTestGradient(image_shape=[4, 3, 3, 1],
                                   kernel_shape=[2, 2, 1],
                                   strides=[1, 1],
                                   rates=[1, 1],
                                   padding="SAME",
                                   use_gpu=use_gpu)

  def _testErosionGradSamePadding_2x2x4(self, use_gpu):
    self._ConstructAndTestGradient(image_shape=[1, 3, 3, 4],
                                   kernel_shape=[2, 2, 4],
                                   strides=[1, 1],
                                   rates=[1, 1],
                                   padding="SAME",
                                   use_gpu=use_gpu)

  def testErosionGrad(self):
    for use_gpu in True, False:
      self._testErosionGradValidPadding_1x1x1(use_gpu)
      self._testErosionGradSamePadding_1x1x1(use_gpu)
      self._testErosionGradSamePadding_1x1x2(use_gpu)
      self._testErosionGradValidPadding_2x2x1(use_gpu)
      self._testErosionGradSamePadding_2x2x1(use_gpu)
      self._testErosionGradSamePaddingBatch_2x2x1(use_gpu)
      self._testErosionGradSamePadding_2x2x4(use_gpu)


if __name__ == "__main__":
  tf.test.main()

@ -131,6 +131,46 @@ to the `Convolution` section for details about the padding calculation.
@@avg_pool3d
@@max_pool3d

## Morphological filtering

Morphological operators are non-linear filters used in image processing.

[Greyscale morphological dilation]
(https://en.wikipedia.org/wiki/Dilation_(morphology)) is the max-sum counterpart
of standard sum-product convolution:

    output[b, y, x, c] =
        max_{dy, dx} input[b,
                           strides[1] * y + rates[1] * dy,
                           strides[2] * x + rates[2] * dx,
                           c] +
                     filter[dy, dx, c]

The `filter` is usually called the structuring function. Max-pooling is a
special case of greyscale morphological dilation when the filter assumes
all-zero values (a.k.a. flat structuring function).

[Greyscale morphological erosion]
(https://en.wikipedia.org/wiki/Erosion_(morphology)) is the min-sum counterpart
of standard sum-product convolution:

    output[b, y, x, c] =
        min_{dy, dx} input[b,
                           strides[1] * y - rates[1] * dy,
                           strides[2] * x - rates[2] * dx,
                           c] -
                     filter[dy, dx, c]

Dilation and erosion are dual to each other. The dilation of the input signal
`f` by the structuring signal `g` is equal to the negation of the erosion of
`-f` by the reflected `g`, and vice versa.

Striding and padding are carried out in exactly the same way as in standard
convolution. Please refer to the `Convolution` section for details.

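As a small sanity check of these definitions and of the duality, consider the
1-D, stride-1, rate-1, VALID case (a NumPy sketch; illustrative only, not part
of the API):

    import numpy as np

    f = np.array([1., 4., 2., 3.])   # input signal
    g = np.array([0., 1.])           # structuring function

    # Dilation: output[x] = max_dx f[x + dx] + g[dx]; with g == 0 this reduces
    # to a sliding max, i.e. max-pooling.
    dil = np.array([max(f[x + dx] + g[dx] for dx in range(len(g)))
                    for x in range(len(f) - len(g) + 1)])
    # Erosion of -f by the reflected g: output[x] = min_dx (-f)[x - dx] - g_r[dx].
    g_r = g[::-1]
    ero = np.array([min(-f[x - dx] - g_r[dx] for dx in range(len(g)))
                    for x in range(len(g) - 1, len(f))])
    assert np.allclose(dil, -ero)    # duality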

@@dilation2d
@@erosion2d

## Normalization

Normalization is useful to prevent neurons from saturating when inputs may
@ -270,6 +270,18 @@ def _DepthwiseConv2dNativeGrad(op, grad):
  ]


@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  return [nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
                                           op.get_attr("strides"),
                                           op.get_attr("rates"),
                                           op.get_attr("padding")),
          nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
                                            op.get_attr("strides"),
                                            op.get_attr("rates"),
                                            op.get_attr("padding"))]
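# In effect, the two backprop ops route the incoming gradient to the (dy, dx)
# tap that attained the forward max-sum, in the same spirit as the max-pooling
# gradient; this is also why the gradient tests use a small finite-difference
# delta, so that the argmax does not change between evaluations.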
@ops.RegisterGradient("LRN")
|
||||
def _LRNGrad(op, grad):
|
||||
depth_radius = op.get_attr("depth_radius")
|
||||
|
@ -1068,4 +1068,142 @@ def conv1d(value, filters, stride, padding,
                               data_format=data_format)
    return array_ops.squeeze(result, [1])


@ops.RegisterShape("Dilation2D")
def _Dilation2DShape(op):
  """Shape function for Dilation2D op."""
  input_shape = op.inputs[0].get_shape().with_rank(4)
  filter_shape = op.inputs[1].get_shape().with_rank(3)

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]
  depth = input_shape[3]

  filter_rows = filter_shape[0]
  filter_cols = filter_shape[1]
  # Check that the input depths are compatible.
  input_shape[3].assert_is_compatible_with(filter_shape[2])

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")

  rate_b, rate_r, rate_c, rate_d = op.get_attr("rates")
  if rate_b != 1 or rate_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "rates in the batch and depth dimensions.")

  filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_r - 1)
  filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_c - 1)
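  # Illustrative example: with rate_r = 2, a 2x2 filter behaves like a 3x3 one,
  # since filter_rows_eff = 2 + (2 - 1) * (2 - 1) = 3; the output size below
  # then follows the usual convolution rules for the chosen stride and padding.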

  padding = op.get_attr("padding")
  out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
                                                            filter_rows_eff,
                                                            filter_cols_eff,
                                                            stride_r, stride_c,
                                                            padding)

  output_shape = [batch_size, out_rows, out_cols, depth]
  return [tensor_shape.TensorShape(output_shape)]

@ops.RegisterShape("Dilation2DBackpropInput")
|
||||
def _Dilation2DBackpropInputShape(op):
|
||||
"""Shape function for Dilation2DBackpropInput op."""
|
||||
return [op.inputs[0].get_shape()]
|
||||
|
||||
|
||||
@ops.RegisterShape("Dilation2DBackpropFilter")
|
||||
def _Dilation2DBackpropFilterShape(op):
|
||||
"""Shape function for Dilation2DBackpropFilter op."""
|
||||
return [op.inputs[1].get_shape()]
|
||||
|
||||
|
||||
@ops.RegisterStatistics("Dilation2D", "flops")
|
||||
def _calc_dilation2d_flops(graph, node):
|
||||
"""Calculates the compute resources needed for Dilation2D."""
|
||||
input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
|
||||
input_shape.assert_is_fully_defined()
|
||||
filter_shape = graph_util.tensor_shape_from_node_def_name(graph,
|
||||
node.input[1])
|
||||
filter_shape.assert_is_fully_defined()
|
||||
output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
|
||||
output_shape.assert_is_fully_defined()
|
||||
filter_height = int(filter_shape[0])
|
||||
filter_width = int(filter_shape[1])
|
||||
output_count = np.prod(output_shape.as_list())
|
||||
return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
|
||||
|
||||
|
||||
@ops.RegisterStatistics("Dilation2D", "weight_parameters")
|
||||
def _calc_dilation2d_weight_params(graph, node):
|
||||
"""Calculates the on-disk size of the weights for Dilation2D."""
|
||||
filter_shape = graph_util.tensor_shape_from_node_def_name(graph,
|
||||
node.input[1])
|
||||
filter_shape.assert_is_fully_defined()
|
||||
filter_height = int(filter_shape[0])
|
||||
filter_width = int(filter_shape[1])
|
||||
filter_depth = int(filter_shape[2])
|
||||
return ops.OpStats("weight_parameters",
|
||||
(filter_height * filter_width * filter_depth))
|
||||
|
||||
|
||||

def erosion2d(value, kernel, strides, rates, padding, name=None):
  """Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.

  The `value` tensor has shape `[batch, in_height, in_width, depth]` and the
  `kernel` tensor has shape `[kernel_height, kernel_width, depth]`, i.e.,
  each input channel is processed independently of the others with its own
  structuring function. The `output` tensor has shape
  `[batch, out_height, out_width, depth]`. The spatial dimensions of the
  output tensor depend on the `padding` algorithm. We currently only support
  the default "NHWC" `data_format`.

  In detail, the grayscale morphological 2-D erosion is given by:

      output[b, y, x, c] =
         min_{dy, dx} value[b,
                            strides[1] * y - rates[1] * dy,
                            strides[2] * x - rates[2] * dx,
                            c] -
                      kernel[dy, dx, c]

  Duality: The erosion of `value` by the `kernel` is equal to the negation of
  the dilation of `-value` by the reflected `kernel`.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, in_height, in_width, depth]`.
    kernel: A `Tensor`. Must have the same type as `value`.
      3-D with shape `[kernel_height, kernel_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      1-D of length 4. The stride of the sliding window for each dimension of
      the input tensor. Must be: `[1, stride_height, stride_width, 1]`.
    rates: A list of `ints` that has length `>= 4`.
      1-D of length 4. The input stride for atrous morphological dilation.
      Must be: `[1, rate_height, rate_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use.
    name: A name for the operation (optional). If not specified "erosion2d"
      is used.

  Returns:
    A `Tensor`. Has the same type as `value`.
    4-D with shape `[batch, out_height, out_width, depth]`.

  Raises:
    ValueError: If the `value` depth does not match the `kernel`'s shape, or
      if padding is other than `'VALID'` or `'SAME'`.
  """
  with ops.op_scope([value, kernel], name, "erosion2d") as name:
    # Reduce erosion to dilation by duality.
    return math_ops.neg(gen_nn_ops.dilation2d(input=math_ops.neg(value),
                                               filter=array_ops.reverse(
                                                   kernel, [True, True, False]),
                                               strides=strides,
                                               rates=rates,
                                               padding=padding,
                                               name=name))

# pylint: enable=invalid-name
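
A minimal usage sketch of the two new Python endpoints (illustrative only;
it assumes the graph/session style used in this era of the API, and the
shapes follow the docstring above):

    import numpy as np
    import tensorflow as tf

    # [batch, in_height, in_width, depth] image; [height, width, depth] kernel.
    image = tf.constant(np.random.rand(1, 5, 5, 1).astype(np.float32))
    kernel = tf.constant(np.zeros((3, 3, 1), np.float32))  # flat structuring function

    dilated = tf.nn.dilation2d(image, kernel, strides=[1, 1, 1, 1],
                               rates=[1, 1, 1, 1], padding="SAME")
    eroded = tf.nn.erosion2d(image, kernel, strides=[1, 1, 1, 1],
                             rates=[1, 1, 1, 1], padding="SAME")

    with tf.Session() as sess:
      d, e = sess.run([dilated, eroded])
    # With an all-zero kernel, dilation acts as a 3x3 max filter and erosion
    # as a 3x3 min filter over the image.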