Support int32 argmax output in MaxPoolWithArgmax
The documentation (https://www.tensorflow.org/api_docs/python/tf/nn/max_pool_with_argmax) indicates that output_dtype supports both int32 and int64, but currently only kernels for int64 are registered.

PiperOrigin-RevId: 347736377
Change-Id: I01ba684fb1486d4f88b47c29d6a6d5f7b102578c
This commit is contained in:
parent
c1e9462055
commit
08cb2c1d43
tensorflow
@ -19,7 +19,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/kernels/maxpooling_op.h"
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
@ -56,7 +58,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
const int kInvalidMaxPoolingIndex = -1;
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename Targmax>
|
||||
static void SpatialMaxPoolWithArgMaxHelper(
|
||||
OpKernelContext* context, Tensor* output, Tensor* output_arg_max,
|
||||
Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop,
|
||||
@ -67,13 +69,17 @@ static void SpatialMaxPoolWithArgMaxHelper(
|
||||
errors::Internal(
|
||||
"SpatialMaxPoolWithArgMaxHelper requires include_batch_in_index "
|
||||
"to be True when input_backprop != nullptr"));
|
||||
OP_REQUIRES(
|
||||
context, (std::is_same<Targmax, int64>::value),
|
||||
errors::Internal("SpatialMaxPoolWithArgMaxHelper requires Targmax "
|
||||
"to be int64 when input_backprop != nullptr"));
|
||||
}
|
||||
|
||||
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||
ConstEigenMatrixMap;
|
||||
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||
EigenMatrixMap;
|
||||
typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
|
||||
typedef Eigen::Map<Eigen::Matrix<Targmax, Eigen::Dynamic, Eigen::Dynamic>>
|
||||
EigenIndexMatrixMap;
|
||||
|
||||
ConstEigenMatrixMap in_mat(
|
||||
@ -83,7 +89,7 @@ static void SpatialMaxPoolWithArgMaxHelper(
|
||||
output->flat<T>().data(), params.depth,
|
||||
params.out_width * params.out_height * params.tensor_in_batch);
|
||||
EigenIndexMatrixMap out_arg_max_mat(
|
||||
output_arg_max->flat<int64>().data(), params.depth,
|
||||
output_arg_max->flat<Targmax>().data(), params.depth,
|
||||
params.out_width * params.out_height * params.tensor_in_batch);
|
||||
|
||||
const DeviceBase::CpuWorkerThreads& worker_threads =
|
||||
@ -150,7 +156,8 @@ static void SpatialMaxPoolWithArgMaxHelper(
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
const T& input_ref = in_mat.coeffRef(d, in_index);
|
||||
T& output_ref = out_mat.coeffRef(d, out_index);
|
||||
int64& out_arg_max_ref = out_arg_max_mat.coeffRef(d, out_index);
|
||||
Targmax& out_arg_max_ref =
|
||||
out_arg_max_mat.coeffRef(d, out_index);
|
||||
if (output_ref < input_ref ||
|
||||
out_arg_max_ref == kInvalidMaxPoolingIndex) {
|
||||
output_ref = input_ref;
|
||||
@ -319,7 +326,7 @@ class MaxPoolingGradOp : public OpKernel {
|
||||
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
||||
{0}, 0, output_shape, &output));
|
||||
|
||||
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
|
||||
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, int64>(
|
||||
context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
|
||||
out_backprop, params, true);
|
||||
}
|
||||
@ -900,22 +907,22 @@ class MaxPoolingNoMaskV2Op : public OpKernel {
|
||||
TensorFormat data_format_;
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename Targmax>
|
||||
struct LaunchMaxPoolingWithArgmax;
|
||||
|
||||
template <typename T>
|
||||
struct LaunchMaxPoolingWithArgmax<CPUDevice, T> {
|
||||
template <typename T, typename Targmax>
|
||||
struct LaunchMaxPoolingWithArgmax<CPUDevice, T, Targmax> {
|
||||
static void launch(OpKernelContext* context, const PoolParameters& params,
|
||||
const Tensor& input, Tensor* output, Tensor* argmax,
|
||||
bool propagate_nans, bool include_batch_in_index) {
|
||||
Tensor unused;
|
||||
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(context, output, argmax,
|
||||
nullptr, input, unused, params,
|
||||
include_batch_in_index);
|
||||
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, Targmax>(
|
||||
context, output, argmax, /*input_backprop=*/nullptr, input, unused,
|
||||
params, include_batch_in_index);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename Targmax>
|
||||
class MaxPoolingWithArgmaxOp : public OpKernel {
|
||||
public:
|
||||
explicit MaxPoolingWithArgmaxOp(OpKernelConstruction* context)
|
||||
@ -959,7 +966,7 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
|
||||
Tensor* argmax = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax));
|
||||
|
||||
LaunchMaxPoolingWithArgmax<Device, T>::launch(
|
||||
LaunchMaxPoolingWithArgmax<Device, T, Targmax>::launch(
|
||||
context, params, tensor_in, output, argmax, propagate_nans_,
|
||||
include_batch_in_index_);
|
||||
}
|
||||
@ -1027,6 +1034,7 @@ struct LaunchMaxPoolingGradWithArgmax<CPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(b/175733711): Support int32 argmax type in MaxPoolGradWithArgmax op.
|
||||
template <typename Device, typename T>
|
||||
class MaxPoolingGradWithArgmaxOp : public OpKernel {
|
||||
public:
|
||||
@ -1363,7 +1371,7 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
|
||||
struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T, int64> {
|
||||
static void launch(OpKernelContext* context, const PoolParameters& params,
|
||||
const Tensor& input, Tensor* output, Tensor* argmax,
|
||||
bool propagate_nans, bool include_batch_in_index) {
|
||||
@ -1456,7 +1464,7 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
|
||||
.Device(DEVICE_##D) \
|
||||
.TypeConstraint<int64>("Targmax") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
MaxPoolingWithArgmaxOp<D##Device, T>); \
|
||||
MaxPoolingWithArgmaxOp<D##Device, T, int64>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax") \
|
||||
.Device(DEVICE_##D) \
|
||||
.TypeConstraint<T>("T") \
|
||||
@ -1470,7 +1478,12 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
|
||||
MaxPoolingOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
MaxPoolingV2Op<CPUDevice, T>);
|
||||
MaxPoolingV2Op<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<int32>("Targmax") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
MaxPoolingWithArgmaxOp<CPUDevice, T, int32>);
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
|
||||
#undef REGISTER_CPU_ONLY_POOL_KERNELS
|
||||
|
||||
|
@ -1004,12 +1004,14 @@ class PoolingTest(test.TestCase):
|
||||
]
|
||||
|
||||
Config = collections.namedtuple(
|
||||
"Config", ["use_gpu", "include_batch_in_index", "argmax"])
|
||||
"Config", ["use_gpu", "include_batch_in_index", "argmax", "Targmax"])
|
||||
configs = [
|
||||
Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8]),
|
||||
Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17]),
|
||||
Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8]),
|
||||
Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17])
|
||||
Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64),
|
||||
Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64),
|
||||
Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int32),
|
||||
Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int32),
|
||||
Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64),
|
||||
Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64),
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
@ -1019,7 +1021,7 @@ class PoolingTest(test.TestCase):
|
||||
t,
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 1, 1, 1],
|
||||
Targmax=dtypes.int64,
|
||||
Targmax=config.Targmax,
|
||||
padding="VALID",
|
||||
include_batch_in_index=config.include_batch_in_index)
|
||||
out, argmax = self.evaluate([out_op, argmax_op])
|
||||
|
Loading…
Reference in New Issue
Block a user