Add int64 axis support for tf.cumsum and tf.cumprod (#13791)

* Add `int64` axis support for `tf.cumsum` and `tf.cumprod`

This fix adds `int64` axis support for `tf.cumsum` and `tf.cumprod`.
Though `int64` is the registered data type for `axis` (`Tidx`), no
kernel is available.

The issue could be described as:
```
>>> import tensorflow as tf
>>> v = tf.cumsum([1, 2, 3], tf.constant(0, tf.int64))
>>> tf.Session().run(v)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 889, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1317, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: No OpKernel was registered to support Op 'Cumsum' with these attrs.  Registered devices: [CPU], Registered kernels:
  device='CPU'; T in [DT_COMPLEX128]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_COMPLEX64]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_DOUBLE]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_FLOAT]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_HALF]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT8]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_UINT8]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT16]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_UINT16]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT32]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT64]; Tidx in [DT_INT32]

	 [[Node: Cumsum = Cumsum[T=DT_INT32, Tidx=DT_INT64, exclusive=false, reverse=false](Cumsum/x, Const)]]

Caused by op u'Cumsum', defined at:
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/math_ops.py", line 2246, in cumsum
    x, axis, exclusive=exclusive, reverse=reverse, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_math_ops.py", line 1370, in cumsum
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2966, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1473, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): No OpKernel was registered to support Op 'Cumsum' with these attrs.  Registered devices: [CPU], Registered kernels:
  device='CPU'; T in [DT_COMPLEX128]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_COMPLEX64]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_DOUBLE]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_FLOAT]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_HALF]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT8]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_UINT8]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT16]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_UINT16]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT32]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT64]; Tidx in [DT_INT32]

	 [[Node: Cumsum = Cumsum[T=DT_INT32, Tidx=DT_INT64, exclusive=false, reverse=false](Cumsum/x, Const)]]

>>>
```

This fix adds the missing kernel.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test cases for `int64` axis support of `tf.cumsum` and `tf.cumprod`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Reformat scan_ops.cc with `clang-format -i`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2017-10-18 09:18:00 -07:00 committed by Vijay Vasudevan
parent bedfe8ac14
commit 139e1e0771
2 changed files with 80 additions and 36 deletions

View File

@ -35,7 +35,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device, class T, typename Reducer>
template <typename Device, class T, typename Reducer, typename Tidx>
class ScanOp : public OpKernel {
public:
explicit ScanOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -51,8 +51,9 @@ class ScanOp : public OpKernel {
errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
tensor_axis.shape().DebugString()));
const int axis_arg = internal::SubtleMustCopy(tensor_axis.scalar<int>()());
const int axis = (axis_arg < 0) ? input.dims() + axis_arg : axis_arg;
const Tidx axis_arg =
internal::SubtleMustCopy(tensor_axis.scalar<Tidx>()());
const Tidx axis = (axis_arg < 0) ? input.dims() + axis_arg : axis_arg;
OP_REQUIRES(ctx, FastBoundsCheck(axis, input.dims()),
errors::InvalidArgument(
"ScanOp: Expected scan axis in the range [", -input.dims(),
@ -70,11 +71,11 @@ class ScanOp : public OpKernel {
// Dim reduction.
int64 reduced_shape[3] = {1, 1, 1};
for (int i = 0; i < axis; ++i) {
for (Tidx i = 0; i < axis; ++i) {
reduced_shape[0] *= input.dim_size(i);
}
reduced_shape[1] = input.dim_size(axis);
for (int i = axis + 1; i < input.dims(); ++i) {
for (Tidx i = axis + 1; i < input.dims(); ++i) {
reduced_shape[2] *= input.dim_size(i);
}
@ -112,51 +113,76 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
} // namespace functor
#endif // GOOGLE_CUDA
// Register Cumsum kernels
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumsum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>>)
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumsum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
REGISTER_KERNEL_BUILDER( \
Name("Cumsum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumsum") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>>)
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumsum") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
REGISTER_KERNEL_BUILDER( \
Name("Cumsum") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("axis"), \
ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
// Register Cumprod kernels
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumprod") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>>)
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumprod") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
REGISTER_KERNEL_BUILDER( \
Name("Cumprod") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumprod") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>>)
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumprod") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
REGISTER_KERNEL_BUILDER( \
Name("Cumprod") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("axis"), \
ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA

View File

@ -20,6 +20,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradient_checker
@ -92,6 +94,14 @@ class CumsumTest(test.TestCase):
for axis in (-1, 0):
self._compareAll(x, axis)
def testAxisType(self):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in [dtypes.int64, dtypes.int32]:
with self.test_session(use_gpu=True):
axis = constant_op.constant(0, axis_dtype)
tf_out = math_ops.cumsum(x, axis).eval()
def test1D(self):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
@ -190,6 +200,14 @@ class CumprodTest(test.TestCase):
for axis in (-1, 0):
self._compareAll(x, axis)
def testAxisType(self):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in [dtypes.int64, dtypes.int32]:
with self.test_session(use_gpu=True):
axis = constant_op.constant(0, axis_dtype)
tf_out = math_ops.cumprod(x, axis).eval()
def test1D(self):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)