From 08cb2c1d43f3aab4458caabb0eabe8e9d2f58b9c Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Tue, 15 Dec 2020 18:56:13 -0800 Subject: [PATCH] Support int32 argmax output in MaxPoolWithArgmax The documentation (https://www.tensorflow.org/api_docs/python/tf/nn/max_pool_with_argmax) indicates that output_dtype supports both int32 and int64 but currently only kernels for int64 are registered. PiperOrigin-RevId: 347736377 Change-Id: I01ba684fb1486d4f88b47c29d6a6d5f7b102578c --- tensorflow/core/kernels/maxpooling_op.cc | 45 ++++++++++++------- .../python/kernel_tests/pooling_ops_test.py | 14 +++--- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index 22e640fbb9a..b60d54533be 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -19,7 +19,9 @@ limitations under the License. #include "tensorflow/core/kernels/maxpooling_op.h" +#include <type_traits> #include <vector> + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/bounds_check.h" @@ -56,7 +58,7 @@ typedef Eigen::GpuDevice GPUDevice; const int kInvalidMaxPoolingIndex = -1; -template <typename Device, typename T> +template <typename Device, typename T, typename Targmax> static void SpatialMaxPoolWithArgMaxHelper( OpKernelContext* context, Tensor* output, Tensor* output_arg_max, Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop, @@ -67,13 +69,17 @@ static void SpatialMaxPoolWithArgMaxHelper( errors::Internal( "SpatialMaxPoolWithArgMaxHelper requires include_batch_in_index " "to be True when input_backprop != nullptr")); + OP_REQUIRES( + context, (std::is_same<Targmax, int64>::value), + errors::Internal("SpatialMaxPoolWithArgMaxHelper requires Targmax " + "to be int64 when input_backprop != nullptr")); } typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> ConstEigenMatrixMap; typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> EigenMatrixMap; - typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>> + typedef Eigen::Map<Eigen::Matrix<Targmax, Eigen::Dynamic, Eigen::Dynamic>> EigenIndexMatrixMap; ConstEigenMatrixMap 
in_mat( @@ -83,7 +89,7 @@ static void SpatialMaxPoolWithArgMaxHelper( output->flat<T>().data(), params.depth, params.out_width * params.out_height * params.tensor_in_batch); EigenIndexMatrixMap out_arg_max_mat( - output_arg_max->flat<int64>().data(), params.depth, + output_arg_max->flat<Targmax>().data(), params.depth, params.out_width * params.out_height * params.tensor_in_batch); const DeviceBase::CpuWorkerThreads& worker_threads = @@ -150,7 +156,8 @@ static void SpatialMaxPoolWithArgMaxHelper( for (int d = 0; d < depth; ++d) { const T& input_ref = in_mat.coeffRef(d, in_index); T& output_ref = out_mat.coeffRef(d, out_index); - int64& out_arg_max_ref = out_arg_max_mat.coeffRef(d, out_index); + Targmax& out_arg_max_ref = + out_arg_max_mat.coeffRef(d, out_index); if (output_ref < input_ref || out_arg_max_ref == kInvalidMaxPoolingIndex) { output_ref = input_ref; @@ -319,7 +326,7 @@ class MaxPoolingGradOp : public OpKernel { OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {0}, 0, output_shape, &output)); - SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>( + SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, int64>( context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in, out_backprop, params, true); } @@ -900,22 +907,22 @@ class MaxPoolingNoMaskV2Op : public OpKernel { TensorFormat data_format_; }; -template <typename Device, typename T> +template <typename Device, typename T, typename Targmax> struct LaunchMaxPoolingWithArgmax; -template <typename T> -struct LaunchMaxPoolingWithArgmax<CPUDevice, T> { +template <typename T, typename Targmax> +struct LaunchMaxPoolingWithArgmax<CPUDevice, T, Targmax> { static void launch(OpKernelContext* context, const PoolParameters& params, const Tensor& input, Tensor* output, Tensor* argmax, bool propagate_nans, bool include_batch_in_index) { Tensor unused; - SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(context, output, argmax, - nullptr, input, unused, params, - include_batch_in_index); + SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, Targmax>( + context, output, argmax, /*input_backprop=*/nullptr, input, unused, + params, include_batch_in_index); } }; -template <typename Device, typename T> +template <typename Device, typename T, typename Targmax> class MaxPoolingWithArgmaxOp : public OpKernel { public: explicit 
MaxPoolingWithArgmaxOp(OpKernelConstruction* context) @@ -959,7 +966,7 @@ class MaxPoolingWithArgmaxOp : public OpKernel { Tensor* argmax = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax)); - LaunchMaxPoolingWithArgmax<Device, T>::launch( + LaunchMaxPoolingWithArgmax<Device, T, Targmax>::launch( context, params, tensor_in, output, argmax, propagate_nans_, include_batch_in_index_); } @@ -1027,6 +1034,7 @@ struct LaunchMaxPoolingGradWithArgmax { } }; +// TODO(b/175733711): Support int32 argmax type in MaxPoolGradWithArgmax op. template <typename Device, typename T> class MaxPoolingGradWithArgmaxOp : public OpKernel { public: @@ -1363,7 +1371,7 @@ struct LaunchMaxPoolingNoMask { }; template <typename T> -struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> { +struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T, int64> { static void launch(OpKernelContext* context, const PoolParameters& params, const Tensor& input, Tensor* output, Tensor* argmax, bool propagate_nans, bool include_batch_in_index) { @@ -1456,7 +1464,7 @@ struct LaunchMaxPoolingGradGradWithArgmax { .Device(DEVICE_##D) \ .TypeConstraint<int64>("Targmax") \ .TypeConstraint<T>("T"), \ - MaxPoolingWithArgmaxOp<D##Device, T>); \ + MaxPoolingWithArgmaxOp<D##Device, T, int64>); \ REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax") \ .Device(DEVICE_##D) \ .TypeConstraint<T>("T") \ @@ -1470,7 +1478,12 @@ struct LaunchMaxPoolingGradGradWithArgmax { MaxPoolingOp<CPUDevice, T>); \ REGISTER_KERNEL_BUILDER( \ Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - MaxPoolingV2Op<CPUDevice, T>); + MaxPoolingV2Op<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<int32>("Targmax") \ + .TypeConstraint<T>("T"), \ + MaxPoolingWithArgmaxOp<CPUDevice, T, int32>); TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS); #undef REGISTER_CPU_ONLY_POOL_KERNELS diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 20699f5de49..98e043e8a80 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -1004,12 +1004,14 @@ class 
PoolingTest(test.TestCase): ] Config = collections.namedtuple( - "Config", ["use_gpu", "include_batch_in_index", "argmax"]) + "Config", ["use_gpu", "include_batch_in_index", "argmax", "Targmax"]) configs = [ - Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8]), - Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17]), - Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8]), - Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17]) + Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64), + Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64), + Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int32), + Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int32), + Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64), + Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64), ] for config in configs: @@ -1019,7 +1021,7 @@ class PoolingTest(test.TestCase): t, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], - Targmax=dtypes.int64, + Targmax=config.Targmax, padding="VALID", include_batch_in_index=config.include_batch_in_index) out, argmax = self.evaluate([out_op, argmax_op])