Support int32 argmax output in MaxPoolWithArgmax
The documentation (https://www.tensorflow.org/api_docs/python/tf/nn/max_pool_with_argmax) indicates that output_dtype supports both int32 and int64 but currently only kernels for int64 are registered. PiperOrigin-RevId: 347736377 Change-Id: I01ba684fb1486d4f88b47c29d6a6d5f7b102578c
This commit is contained in:
parent
c1e9462055
commit
08cb2c1d43
@ -19,7 +19,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/kernels/maxpooling_op.h"
|
#include "tensorflow/core/kernels/maxpooling_op.h"
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
@ -56,7 +58,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
|||||||
|
|
||||||
const int kInvalidMaxPoolingIndex = -1;
|
const int kInvalidMaxPoolingIndex = -1;
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T, typename Targmax>
|
||||||
static void SpatialMaxPoolWithArgMaxHelper(
|
static void SpatialMaxPoolWithArgMaxHelper(
|
||||||
OpKernelContext* context, Tensor* output, Tensor* output_arg_max,
|
OpKernelContext* context, Tensor* output, Tensor* output_arg_max,
|
||||||
Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop,
|
Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop,
|
||||||
@ -67,13 +69,17 @@ static void SpatialMaxPoolWithArgMaxHelper(
|
|||||||
errors::Internal(
|
errors::Internal(
|
||||||
"SpatialMaxPoolWithArgMaxHelper requires include_batch_in_index "
|
"SpatialMaxPoolWithArgMaxHelper requires include_batch_in_index "
|
||||||
"to be True when input_backprop != nullptr"));
|
"to be True when input_backprop != nullptr"));
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, (std::is_same<Targmax, int64>::value),
|
||||||
|
errors::Internal("SpatialMaxPoolWithArgMaxHelper requires Targmax "
|
||||||
|
"to be int64 when input_backprop != nullptr"));
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
ConstEigenMatrixMap;
|
ConstEigenMatrixMap;
|
||||||
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
EigenMatrixMap;
|
EigenMatrixMap;
|
||||||
typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
|
typedef Eigen::Map<Eigen::Matrix<Targmax, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
EigenIndexMatrixMap;
|
EigenIndexMatrixMap;
|
||||||
|
|
||||||
ConstEigenMatrixMap in_mat(
|
ConstEigenMatrixMap in_mat(
|
||||||
@ -83,7 +89,7 @@ static void SpatialMaxPoolWithArgMaxHelper(
|
|||||||
output->flat<T>().data(), params.depth,
|
output->flat<T>().data(), params.depth,
|
||||||
params.out_width * params.out_height * params.tensor_in_batch);
|
params.out_width * params.out_height * params.tensor_in_batch);
|
||||||
EigenIndexMatrixMap out_arg_max_mat(
|
EigenIndexMatrixMap out_arg_max_mat(
|
||||||
output_arg_max->flat<int64>().data(), params.depth,
|
output_arg_max->flat<Targmax>().data(), params.depth,
|
||||||
params.out_width * params.out_height * params.tensor_in_batch);
|
params.out_width * params.out_height * params.tensor_in_batch);
|
||||||
|
|
||||||
const DeviceBase::CpuWorkerThreads& worker_threads =
|
const DeviceBase::CpuWorkerThreads& worker_threads =
|
||||||
@ -150,7 +156,8 @@ static void SpatialMaxPoolWithArgMaxHelper(
|
|||||||
for (int d = 0; d < depth; ++d) {
|
for (int d = 0; d < depth; ++d) {
|
||||||
const T& input_ref = in_mat.coeffRef(d, in_index);
|
const T& input_ref = in_mat.coeffRef(d, in_index);
|
||||||
T& output_ref = out_mat.coeffRef(d, out_index);
|
T& output_ref = out_mat.coeffRef(d, out_index);
|
||||||
int64& out_arg_max_ref = out_arg_max_mat.coeffRef(d, out_index);
|
Targmax& out_arg_max_ref =
|
||||||
|
out_arg_max_mat.coeffRef(d, out_index);
|
||||||
if (output_ref < input_ref ||
|
if (output_ref < input_ref ||
|
||||||
out_arg_max_ref == kInvalidMaxPoolingIndex) {
|
out_arg_max_ref == kInvalidMaxPoolingIndex) {
|
||||||
output_ref = input_ref;
|
output_ref = input_ref;
|
||||||
@ -319,7 +326,7 @@ class MaxPoolingGradOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
||||||
{0}, 0, output_shape, &output));
|
{0}, 0, output_shape, &output));
|
||||||
|
|
||||||
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
|
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, int64>(
|
||||||
context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
|
context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
|
||||||
out_backprop, params, true);
|
out_backprop, params, true);
|
||||||
}
|
}
|
||||||
@ -900,22 +907,22 @@ class MaxPoolingNoMaskV2Op : public OpKernel {
|
|||||||
TensorFormat data_format_;
|
TensorFormat data_format_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T, typename Targmax>
|
||||||
struct LaunchMaxPoolingWithArgmax;
|
struct LaunchMaxPoolingWithArgmax;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename Targmax>
|
||||||
struct LaunchMaxPoolingWithArgmax<CPUDevice, T> {
|
struct LaunchMaxPoolingWithArgmax<CPUDevice, T, Targmax> {
|
||||||
static void launch(OpKernelContext* context, const PoolParameters& params,
|
static void launch(OpKernelContext* context, const PoolParameters& params,
|
||||||
const Tensor& input, Tensor* output, Tensor* argmax,
|
const Tensor& input, Tensor* output, Tensor* argmax,
|
||||||
bool propagate_nans, bool include_batch_in_index) {
|
bool propagate_nans, bool include_batch_in_index) {
|
||||||
Tensor unused;
|
Tensor unused;
|
||||||
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(context, output, argmax,
|
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T, Targmax>(
|
||||||
nullptr, input, unused, params,
|
context, output, argmax, /*input_backprop=*/nullptr, input, unused,
|
||||||
include_batch_in_index);
|
params, include_batch_in_index);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T, typename Targmax>
|
||||||
class MaxPoolingWithArgmaxOp : public OpKernel {
|
class MaxPoolingWithArgmaxOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit MaxPoolingWithArgmaxOp(OpKernelConstruction* context)
|
explicit MaxPoolingWithArgmaxOp(OpKernelConstruction* context)
|
||||||
@ -959,7 +966,7 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
|
|||||||
Tensor* argmax = nullptr;
|
Tensor* argmax = nullptr;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax));
|
OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax));
|
||||||
|
|
||||||
LaunchMaxPoolingWithArgmax<Device, T>::launch(
|
LaunchMaxPoolingWithArgmax<Device, T, Targmax>::launch(
|
||||||
context, params, tensor_in, output, argmax, propagate_nans_,
|
context, params, tensor_in, output, argmax, propagate_nans_,
|
||||||
include_batch_in_index_);
|
include_batch_in_index_);
|
||||||
}
|
}
|
||||||
@ -1027,6 +1034,7 @@ struct LaunchMaxPoolingGradWithArgmax<CPUDevice, T> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO(b/175733711): Support int32 argmax type in MaxPoolGradWithArgmax op.
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MaxPoolingGradWithArgmaxOp : public OpKernel {
|
class MaxPoolingGradWithArgmaxOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
@ -1363,7 +1371,7 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
|
struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T, int64> {
|
||||||
static void launch(OpKernelContext* context, const PoolParameters& params,
|
static void launch(OpKernelContext* context, const PoolParameters& params,
|
||||||
const Tensor& input, Tensor* output, Tensor* argmax,
|
const Tensor& input, Tensor* output, Tensor* argmax,
|
||||||
bool propagate_nans, bool include_batch_in_index) {
|
bool propagate_nans, bool include_batch_in_index) {
|
||||||
@ -1456,7 +1464,7 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
|
|||||||
.Device(DEVICE_##D) \
|
.Device(DEVICE_##D) \
|
||||||
.TypeConstraint<int64>("Targmax") \
|
.TypeConstraint<int64>("Targmax") \
|
||||||
.TypeConstraint<T>("T"), \
|
.TypeConstraint<T>("T"), \
|
||||||
MaxPoolingWithArgmaxOp<D##Device, T>); \
|
MaxPoolingWithArgmaxOp<D##Device, T, int64>); \
|
||||||
REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax") \
|
REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax") \
|
||||||
.Device(DEVICE_##D) \
|
.Device(DEVICE_##D) \
|
||||||
.TypeConstraint<T>("T") \
|
.TypeConstraint<T>("T") \
|
||||||
@ -1470,7 +1478,12 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
|
|||||||
MaxPoolingOp<CPUDevice, T>); \
|
MaxPoolingOp<CPUDevice, T>); \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||||
MaxPoolingV2Op<CPUDevice, T>);
|
MaxPoolingV2Op<CPUDevice, T>); \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<int32>("Targmax") \
|
||||||
|
.TypeConstraint<T>("T"), \
|
||||||
|
MaxPoolingWithArgmaxOp<CPUDevice, T, int32>);
|
||||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
|
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
|
||||||
#undef REGISTER_CPU_ONLY_POOL_KERNELS
|
#undef REGISTER_CPU_ONLY_POOL_KERNELS
|
||||||
|
|
||||||
|
@ -1004,12 +1004,14 @@ class PoolingTest(test.TestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
Config = collections.namedtuple(
|
Config = collections.namedtuple(
|
||||||
"Config", ["use_gpu", "include_batch_in_index", "argmax"])
|
"Config", ["use_gpu", "include_batch_in_index", "argmax", "Targmax"])
|
||||||
configs = [
|
configs = [
|
||||||
Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8]),
|
Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64),
|
||||||
Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17]),
|
Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64),
|
||||||
Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8]),
|
Config(False, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int32),
|
||||||
Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17])
|
Config(False, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int32),
|
||||||
|
Config(True, False, [0, 1, 3, 5, 0, 2, 6, 8], dtypes.int64),
|
||||||
|
Config(True, True, [0, 1, 3, 5, 9, 11, 15, 17], dtypes.int64),
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in configs:
|
for config in configs:
|
||||||
@ -1019,7 +1021,7 @@ class PoolingTest(test.TestCase):
|
|||||||
t,
|
t,
|
||||||
ksize=[1, 2, 2, 1],
|
ksize=[1, 2, 2, 1],
|
||||||
strides=[1, 1, 1, 1],
|
strides=[1, 1, 1, 1],
|
||||||
Targmax=dtypes.int64,
|
Targmax=config.Targmax,
|
||||||
padding="VALID",
|
padding="VALID",
|
||||||
include_batch_in_index=config.include_batch_in_index)
|
include_batch_in_index=config.include_batch_in_index)
|
||||||
out, argmax = self.evaluate([out_op, argmax_op])
|
out, argmax = self.evaluate([out_op, argmax_op])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user