Add int64 axis support for tf.cumsum and tf.cumprod (#13791)
* Add `int64` axis support for `tf.cumsum` and `tf.cumprod`

This fix adds `int64` axis support for `tf.cumsum` and `tf.cumprod`.

Though `int64` is the registered data type for `axis` (`Tidx`), no kernel is available.

The issue could be described as:

```
>>> import tensorflow as tf
>>> v = tf.cumsum([1, 2, 3], tf.constant(0, tf.int64))
>>> tf.Session().run(v)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 889, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1317, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: No OpKernel was registered to support Op 'Cumsum' with these attrs.
Registered devices: [CPU], Registered kernels:
  device='CPU'; T in [DT_COMPLEX128]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_COMPLEX64]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_DOUBLE]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_FLOAT]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_HALF]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT8]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_UINT8]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT16]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_UINT16]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT32]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT64]; Tidx in [DT_INT32]

  [[Node: Cumsum = Cumsum[T=DT_INT32, Tidx=DT_INT64, exclusive=false, reverse=false](Cumsum/x, Const)]]

Caused by op u'Cumsum', defined at:
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/math_ops.py", line 2246, in cumsum
    x, axis, exclusive=exclusive, reverse=reverse, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_math_ops.py", line 1370, in cumsum
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2966, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1473, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): No OpKernel was registered to support Op 'Cumsum' with these attrs.
Registered devices: [CPU], Registered kernels:
  device='CPU'; T in [DT_COMPLEX128]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_COMPLEX64]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_DOUBLE]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_FLOAT]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_HALF]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT8]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_UINT8]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT16]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_UINT16]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT32]; Tidx in [DT_INT32]
  device='CPU'; T in [DT_INT64]; Tidx in [DT_INT32]

  [[Node: Cumsum = Cumsum[T=DT_INT32, Tidx=DT_INT64, exclusive=false, reverse=false](Cumsum/x, Const)]]

>>>
```

This fix adds the missing kernel.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test cases for `int64` axis support of `tf.cumsum` and `tf.cumprod`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Reformat scan_ops.cc with `clang-format -i`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
bedfe8ac14
commit
139e1e0771
@ -35,7 +35,7 @@ namespace tensorflow {
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename Device, class T, typename Reducer>
|
||||
template <typename Device, class T, typename Reducer, typename Tidx>
|
||||
class ScanOp : public OpKernel {
|
||||
public:
|
||||
explicit ScanOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
@ -51,8 +51,9 @@ class ScanOp : public OpKernel {
|
||||
errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
|
||||
tensor_axis.shape().DebugString()));
|
||||
|
||||
const int axis_arg = internal::SubtleMustCopy(tensor_axis.scalar<int>()());
|
||||
const int axis = (axis_arg < 0) ? input.dims() + axis_arg : axis_arg;
|
||||
const Tidx axis_arg =
|
||||
internal::SubtleMustCopy(tensor_axis.scalar<Tidx>()());
|
||||
const Tidx axis = (axis_arg < 0) ? input.dims() + axis_arg : axis_arg;
|
||||
OP_REQUIRES(ctx, FastBoundsCheck(axis, input.dims()),
|
||||
errors::InvalidArgument(
|
||||
"ScanOp: Expected scan axis in the range [", -input.dims(),
|
||||
@ -70,11 +71,11 @@ class ScanOp : public OpKernel {
|
||||
|
||||
// Dim reduction.
|
||||
int64 reduced_shape[3] = {1, 1, 1};
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
for (Tidx i = 0; i < axis; ++i) {
|
||||
reduced_shape[0] *= input.dim_size(i);
|
||||
}
|
||||
reduced_shape[1] = input.dim_size(axis);
|
||||
for (int i = axis + 1; i < input.dims(); ++i) {
|
||||
for (Tidx i = axis + 1; i < input.dims(); ++i) {
|
||||
reduced_shape[2] *= input.dim_size(i);
|
||||
}
|
||||
|
||||
@ -112,51 +113,76 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
|
||||
} // namespace functor
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
|
||||
// Register Cumsum kernels
|
||||
#define REGISTER_CPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumsum") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx"), \
|
||||
ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>>)
|
||||
#define REGISTER_CPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumsum") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx"), \
|
||||
ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumsum") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int64>("Tidx"), \
|
||||
ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
|
||||
#undef REGISTER_CPU_KERNELS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_GPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumsum") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("axis"), \
|
||||
ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>>)
|
||||
#define REGISTER_GPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumsum") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("axis"), \
|
||||
ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumsum") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int64>("Tidx") \
|
||||
.HostMemory("axis"), \
|
||||
ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
|
||||
#undef REGISTER_GPU_KERNELS
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
// Register Cumprod kernels
|
||||
#define REGISTER_CPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumprod") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx"), \
|
||||
ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>>)
|
||||
#define REGISTER_CPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumprod") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx"), \
|
||||
ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumprod") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int64>("Tidx"), \
|
||||
ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
|
||||
#undef REGISTER_CPU_KERNELS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_GPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumprod") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("axis"), \
|
||||
ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>>)
|
||||
#define REGISTER_GPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumprod") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("axis"), \
|
||||
ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Cumprod") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int64>("Tidx") \
|
||||
.HostMemory("axis"), \
|
||||
ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
|
||||
#undef REGISTER_GPU_KERNELS
|
||||
#endif // GOOGLE_CUDA
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
@ -92,6 +94,14 @@ class CumsumTest(test.TestCase):
|
||||
for axis in (-1, 0):
|
||||
self._compareAll(x, axis)
|
||||
|
||||
def testAxisType(self):
  """Checks `cumsum` with both int32 and int64 `axis` tensors.

  For every dtype in `self.valid_dtypes`, the op is evaluated with a scalar
  axis constant of each index dtype and the result is compared against
  `np.cumsum`, so a kernel that dispatches but computes the wrong values is
  caught as well as a missing kernel.
  """
  for dtype in self.valid_dtypes:
    x = np.arange(1, 6).reshape([5]).astype(dtype)
    # The reference result does not depend on the axis dtype.
    np_out = np.cumsum(x, 0)
    for axis_dtype in [dtypes.int64, dtypes.int32]:
      with self.test_session(use_gpu=True):
        axis = constant_op.constant(0, axis_dtype)
        tf_out = math_ops.cumsum(x, axis).eval()
        self.assertAllClose(np_out, tf_out)
|
||||
|
||||
def test1D(self):
|
||||
for dtype in self.valid_dtypes:
|
||||
x = np.arange(1, 6).reshape([5]).astype(dtype)
|
||||
@ -190,6 +200,14 @@ class CumprodTest(test.TestCase):
|
||||
for axis in (-1, 0):
|
||||
self._compareAll(x, axis)
|
||||
|
||||
def testAxisType(self):
  """Checks `cumprod` with both int32 and int64 `axis` tensors.

  For every dtype in `self.valid_dtypes`, the op is evaluated with a scalar
  axis constant of each index dtype and the result is compared against
  `np.cumprod`, so a kernel that dispatches but computes the wrong values is
  caught as well as a missing kernel.
  """
  for dtype in self.valid_dtypes:
    x = np.arange(1, 6).reshape([5]).astype(dtype)
    # The reference result does not depend on the axis dtype.
    np_out = np.cumprod(x, 0)
    for axis_dtype in [dtypes.int64, dtypes.int32]:
      with self.test_session(use_gpu=True):
        axis = constant_op.constant(0, axis_dtype)
        tf_out = math_ops.cumprod(x, axis).eval()
        self.assertAllClose(np_out, tf_out)
|
||||
|
||||
def test1D(self):
|
||||
for dtype in self.valid_dtypes:
|
||||
x = np.arange(1, 6).reshape([5]).astype(dtype)
|
||||
|
Loading…
Reference in New Issue
Block a user