Fix crash when int64 axis is passed to tf.reduce_sum (#13863)

* Fix crash when `int64` axis is passed to `tf.reduce_sum`

This fix tries to fix the crash triggered by `int64` axis passed
to `tf.reduce_sum`:
```
ubuntu@ubuntu:~/tensorflow2$ (cd && python)
Python 2.7.12 (default, Nov 19 2016, 06:48:10)
[GCC 5.4.0 20160609] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> v = tf.reduce_sum([1,2,3], tf.constant(0, tf.int64))
2017-10-20 15:55:06.993430: F tensorflow/core/framework/tensor.cc:601] Check failed: dtype() == expected_dtype (9 vs. 3)
ubuntu@ubuntu:~/tensorflow2$
```

The issue is caused by the fact that shape inference in `common_shape_fns.cc`
only assumes int32 without proper handling of diffent types. In `math_ops.cc`
both int32 and int64 are mentioned.

NOTE that this fix does not address the issue that int64 is not supported.
To allow int64 axis it is more than adding a template in `ReductionOp` as the type
of the axis seems to be decided by some other ways in Eigen.

This fix merely fixed the crash so that an error message will return without
exit from the python program "No OpKernel was registered to support Op 'Sum' with these attrs".

Still, I think its worth to at least allow the program to continue in case of unsupported kernel.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Update implementation with a template helper function.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2017-10-20 17:40:40 -07:00 committed by Vijay Vasudevan
parent f758b24a82
commit df8bce63d6

View File

@ -1020,6 +1020,29 @@ Status UnknownShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
template <typename T>
Status ReductionShapeHelper(const Tensor* reduction_indices_t,
const int32 input_rank,
std::set<int64>& true_indices) {
auto reduction_indices = reduction_indices_t->flat<T>();
for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
const T reduction_index = reduction_indices(i);
if (reduction_index < -input_rank || reduction_index >= input_rank) {
return errors::InvalidArgument("Invalid reduction dimension ",
reduction_index, " for input with ",
input_rank, " dimensions.");
}
auto wrapped_index = reduction_index;
if (wrapped_index < 0) {
wrapped_index += input_rank;
}
true_indices.insert(wrapped_index);
}
return Status::OK();
}
Status ReductionShape(InferenceContext* c) {
ShapeHandle input = c->input(0);
@ -1050,22 +1073,16 @@ Status ReductionShape(InferenceContext* c) {
}
const int32 input_rank = c->Rank(input);
std::set<int32> true_indices;
auto reduction_indices = reduction_indices_t->flat<int32>();
for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
int32 reduction_index = reduction_indices(i);
if (reduction_index < -input_rank || reduction_index >= input_rank) {
return errors::InvalidArgument("Invalid reduction dimension ",
reduction_index, " for input with ",
input_rank, " dimensions.");
}
int32 wrapped_index = reduction_index;
if (wrapped_index < 0) {
wrapped_index += input_rank;
}
true_indices.insert(wrapped_index);
std::set<int64> true_indices;
if (reduction_indices_t->dtype() == DataType::DT_INT32) {
TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
input_rank, true_indices));
} else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
input_rank, true_indices));
} else {
return errors::InvalidArgument(
"reduction_indices can only be int32 or int64");
}
std::vector<DimensionHandle> dims;
@ -1319,11 +1336,10 @@ Status ScatterNdUpdateShape(InferenceContext* c) {
Status s = c->Merge(prefix_indices, prefix_updates, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"The outer ", num_outer_dims,
" dimensions of indices.shape=", c->DebugString(indices_shape),
" must match the outer ", num_outer_dims,
" dimensions of updates.shape=", c->DebugString(updates_shape),
": ", s.error_message());
"The outer ", num_outer_dims, " dimensions of indices.shape=",
c->DebugString(indices_shape), " must match the outer ",
num_outer_dims, " dimensions of updates.shape=",
c->DebugString(updates_shape), ": ", s.error_message());
}
ShapeHandle input_suffix;