Support int32 argmax output in MaxPoolWithArgmax

The documentation (https://www.tensorflow.org/api_docs/python/tf/nn/max_pool_with_argmax)
indicates that output_dtype supports both int32 and int64, but only int64
kernels are currently registered. This change adds CPU kernels for int32
argmax output; MaxPoolGradWithArgmax still requires int64 (b/175733711).

PiperOrigin-RevId: 347736377
Change-Id: I01ba684fb1486d4f88b47c29d6a6d5f7b102578c
Author:    Thai Nguyen
Date:      2020-12-15 18:56:13 -08:00
Committed: TensorFlower Gardener
Parent:    c1e9462055
Commit:    08cb2c1d43
2 changed files with 37 additions and 22 deletions
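With the new CPU registration, both index dtypes work through the public API. A minimal sketch (TF2 eager execution; the tensor values are illustrative, not taken from this change):

import tensorflow as tf

# A single 2x2 image with one channel; the max (4.0) sits at flat index 3.
x = tf.reshape(tf.constant([1.0, 2.0, 3.0, 4.0]), [1, 2, 2, 1])
out, argmax = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding="VALID",
    output_dtype=tf.int32)  # previously only tf.int64 had a registered kernel
print(argmax.dtype)            # int32
print(argmax.numpy().ravel())  # [3]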

tensorflow/core/kernels/maxpooling_op.cc

@@ -19,7 +19,9 @@ limitations under the License.
 #include "tensorflow/core/kernels/maxpooling_op.h"
+#include <type_traits>
+#include <vector>
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/bounds_check.h"
@@ -56,7 +58,7 @@ typedef Eigen::GpuDevice GPUDevice;
 const int kInvalidMaxPoolingIndex = -1;
-template <typename Device, typename T>
+template <typename Device, typename T, typename Targmax>
 static void SpatialMaxPoolWithArgMaxHelper(
     OpKernelContext* context, Tensor* output, Tensor* output_arg_max,
     Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop,
@@ -67,13 +69,17 @@ static void SpatialMaxPoolWithArgMaxHelper(
         errors::Internal(
             "SpatialMaxPoolWithArgMaxHelper requires include_batch_in_index "
             "to be True when input_backprop != nullptr"));
+    OP_REQUIRES(
+        context, (std::is_same<Targmax, int64>::value),
+        errors::Internal("SpatialMaxPoolWithArgMaxHelper requires Targmax "
+                         "to be int64 when input_backprop != nullptr"));
   }
   typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
       ConstEigenMatrixMap;
   typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
       EigenMatrixMap;
-  typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
+  typedef Eigen::Map<Eigen::Matrix<Targmax, Eigen::Dynamic, Eigen::Dynamic>>
       EigenIndexMatrixMap;
   ConstEigenMatrixMap in_mat(
@@ -83,7 +89,7 @@ static void SpatialMaxPoolWithArgMaxHelper(
       output->flat<T>().data(), params.depth,
       params.out_width * params.out_height * params.tensor_in_batch);
   EigenIndexMatrixMap out_arg_max_mat(
-      output_arg_max->flat<int64>().data(), params.depth,
+      output_arg_max->flat<Targmax>().data(), params.depth,
       params.out_width * params.out_height * params.tensor_in_batch);
   const DeviceBase::CpuWorkerThreads& worker_threads =
@@ -150,7 +156,8 @@ static void SpatialMaxPoolWithArgMaxHelper(
           for (int d = 0; d < depth; ++d) {
             const T& input_ref = in_mat.coeffRef(d, in_index);
             T& output_ref = out_mat.coeffRef(d, out_index);
-            int64& out_arg_max_ref = out_arg_max_mat.coeffRef(d, out_index);
+            Targmax& out_arg_max_ref =
+                out_arg_max_mat.coeffRef(d, out_index);
             if (output_ref < input_ref ||
                 out_arg_max_ref == kInvalidMaxPoolingIndex) {
               output_ref = input_ref;
@@ -319,7 +326,7 @@ class MaxPoolingGradOp : public OpKernel {
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {0}, 0, output_shape, &output));
-    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
+    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, int64>(
         context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
         out_backprop, params, true);
   }
@@ -900,22 +907,22 @@ class MaxPoolingNoMaskV2Op : public OpKernel {
   TensorFormat data_format_;
 };
-template <typename Device, typename T>
+template <typename Device, typename T, typename Targmax>
 struct LaunchMaxPoolingWithArgmax;
-template <typename T>
-struct LaunchMaxPoolingWithArgmax<CPUDevice, T> {
+template <typename T, typename Targmax>
+struct LaunchMaxPoolingWithArgmax<CPUDevice, T, Targmax> {
   static void launch(OpKernelContext* context, const PoolParameters& params,
                      const Tensor& input, Tensor* output, Tensor* argmax,
                      bool propagate_nans, bool include_batch_in_index) {
     Tensor unused;
-    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(context, output, argmax,
-                                                 nullptr, input, unused, params,
-                                                 include_batch_in_index);
+    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, Targmax>(
+        context, output, argmax, /*input_backprop=*/nullptr, input, unused,
+        params, include_batch_in_index);
   }
 };
-template <typename Device, typename T>
+template <typename Device, typename T, typename Targmax>
 class MaxPoolingWithArgmaxOp : public OpKernel {
  public:
   explicit MaxPoolingWithArgmaxOp(OpKernelConstruction* context)
@@ -959,7 +966,7 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
     Tensor* argmax = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax));
-    LaunchMaxPoolingWithArgmax<Device, T>::launch(
+    LaunchMaxPoolingWithArgmax<Device, T, Targmax>::launch(
         context, params, tensor_in, output, argmax, propagate_nans_,
         include_batch_in_index_);
   }
@@ -1027,6 +1034,7 @@ struct LaunchMaxPoolingGradWithArgmax<CPUDevice, T> {
   }
 };
+// TODO(b/175733711): Support int32 argmax type in MaxPoolGradWithArgmax op.
 template <typename Device, typename T>
 class MaxPoolingGradWithArgmaxOp : public OpKernel {
  public:
@@ -1363,7 +1371,7 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
 };
 template <typename T>
-struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
+struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T, int64> {
   static void launch(OpKernelContext* context, const PoolParameters& params,
                      const Tensor& input, Tensor* output, Tensor* argmax,
                      bool propagate_nans, bool include_batch_in_index) {
@@ -1456,7 +1464,7 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
                               .Device(DEVICE_##D)                        \
                               .TypeConstraint<int64>("Targmax")          \
                               .TypeConstraint<T>("T"),                   \
-                          MaxPoolingWithArgmaxOp<D##Device, T>);         \
+                          MaxPoolingWithArgmaxOp<D##Device, T, int64>);  \
   REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")                  \
                               .Device(DEVICE_##D)                        \
                               .TypeConstraint<T>("T")                    \
@@ -1470,7 +1478,12 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
       MaxPoolingOp<CPUDevice, T>);                                 \
   REGISTER_KERNEL_BUILDER(                                         \
       Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      MaxPoolingV2Op<CPUDevice, T>);
+      MaxPoolingV2Op<CPUDevice, T>);                               \
+  REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")                \
+                              .Device(DEVICE_CPU)                  \
+                              .TypeConstraint<int32>("Targmax")    \
+                              .TypeConstraint<T>("T"),             \
+                          MaxPoolingWithArgmaxOp<CPUDevice, T, int32>);
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
 #undef REGISTER_CPU_ONLY_POOL_KERNELS
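The int64 guard added to SpatialMaxPoolWithArgMaxHelper exists because the backward op still has only int64 kernels, per the TODO(b/175733711) above. A sketch of the distinction, assuming TF2 eager execution:

import tensorflow as tf

x = tf.random.normal([1, 4, 4, 1])
with tf.GradientTape() as tape:
  tape.watch(x)
  # Default output_dtype=tf.int64: forward and backward both have kernels.
  out, argmax = tf.nn.max_pool_with_argmax(
      x, ksize=2, strides=2, padding="SAME")
  loss = tf.reduce_sum(out)
grad = tape.gradient(loss, x)  # routed through MaxPoolGradWithArgmax (int64)

# With output_dtype=tf.int32 the forward pass now runs on CPU, but taking
# this gradient would still need an int32 MaxPoolGradWithArgmax kernel,
# which this change does not add.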

tensorflow/python/kernel_tests/pooling_ops_test.py

@@ -1004,12 +1004,14 @@ class PoolingTest(test.TestCase):
     ]
     Config = collections.namedtuple(
-        "Config", ["use_gpu", "include_batch_in_index", "argmax"])
+        "Config", ["use_gpu", "include_batch_in_index", "argmax", "Targmax"])
     configs = [
-        Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8]),
-        Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17]),
-        Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8]),
-        Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17])
+        Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64),
+        Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64),
+        Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int32),
+        Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int32),
+        Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64),
+        Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64),
     ]
     for config in configs:
@@ -1019,7 +1021,7 @@ class PoolingTest(test.TestCase):
           t,
           ksize=[1, 2, 2, 1],
           strides=[1, 1, 1, 1],
-          Targmax=dtypes.int64,
+          Targmax=config.Targmax,
           padding="VALID",
           include_batch_in_index=config.include_batch_in_index)
       out, argmax = self.evaluate([out_op, argmax_op])
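For reference, the expected argmax values in the configs follow the flattened-index formula from the op's documentation: ((b * height + y) * width + x) * channels + c when include_batch_in_index is True, and (y * width + x) * channels + c otherwise. A quick sanity check, assuming the test input is two 3x3x1 images (inferred from the constant offset of 9 between the two variants, not shown in this excerpt):

import numpy as np

# Indices without the batch term, from the expected values above.
no_batch = np.array([0, 1, 3, 5, 0, 2, 6, 8])
# Adding b * height * width * channels (= 1 * 3 * 3 * 1 = 9) to the second
# image's entries reproduces the include_batch_in_index=True expectations.
with_batch = np.concatenate([no_batch[:4], no_batch[4:] + 9])
assert with_batch.tolist() == [0, 1, 3, 5, 9, 11, 15, 17]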