From df8bce63d6de6e728e69eb9f45862b816f88a0db Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 20 Oct 2017 17:40:40 -0700 Subject: [PATCH] Fix crash when `int64` axis is passed to `tf.reduce_sum` (#13863) * Fix crash when `int64` axis is passed to `tf.reduce_sum` This fix tries to fix the crash triggered by `int64` axis passed to `tf.reduce_sum`: ``` ubuntu@ubuntu:~/tensorflow2$ (cd && python) Python 2.7.12 (default, Nov 19 2016, 06:48:10) [GCC 5.4.0 20160609] on linux2 Type "help", "copyright", "credits" or "license" for more information. >>> import tensorflow as tf >>> v = tf.reduce_sum([1,2,3], tf.constant(0, tf.int64)) 2017-10-20 15:55:06.993430: F tensorflow/core/framework/tensor.cc:601] Check failed: dtype() == expected_dtype (9 vs. 3) ubuntu@ubuntu:~/tensorflow2$ ``` The issue is caused by the fact that shape inference in `common_shape_fns.cc` only assumes int32 without proper handling of diffent types. In `math_ops.cc` both int32 and int64 are mentioned. NOTE that this fix does not address the issue that int64 is not supported. To allow int64 axis it is more than adding a template in `ReductionOp` as the type of the axis seems to be decided by some other ways in Eigen. This fix merely fixed the crash so that an error message will return without exit from the python program "No OpKernel was registered to support Op 'Sum' with these attrs". Still, I think its worth to at least allow the program to continue in case of unsupported kernel. Signed-off-by: Yong Tang * Update implementation with a template helper function. Signed-off-by: Yong Tang --- tensorflow/core/framework/common_shape_fns.cc | 58 ++++++++++++------- 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 4796c3c00a4..315c99d32bf 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1020,6 +1020,29 @@ Status UnknownShape(shape_inference::InferenceContext* c) { return Status::OK(); } +template +Status ReductionShapeHelper(const Tensor* reduction_indices_t, + const int32 input_rank, + std::set& true_indices) { + auto reduction_indices = reduction_indices_t->flat(); + for (int i = 0; i < reduction_indices_t->NumElements(); ++i) { + const T reduction_index = reduction_indices(i); + if (reduction_index < -input_rank || reduction_index >= input_rank) { + return errors::InvalidArgument("Invalid reduction dimension ", + reduction_index, " for input with ", + input_rank, " dimensions."); + } + + auto wrapped_index = reduction_index; + if (wrapped_index < 0) { + wrapped_index += input_rank; + } + + true_indices.insert(wrapped_index); + } + return Status::OK(); +} + Status ReductionShape(InferenceContext* c) { ShapeHandle input = c->input(0); @@ -1050,22 +1073,16 @@ Status ReductionShape(InferenceContext* c) { } const int32 input_rank = c->Rank(input); - std::set true_indices; - auto reduction_indices = reduction_indices_t->flat(); - for (int i = 0; i < reduction_indices_t->NumElements(); ++i) { - int32 reduction_index = reduction_indices(i); - if (reduction_index < -input_rank || reduction_index >= input_rank) { - return errors::InvalidArgument("Invalid reduction dimension ", - reduction_index, " for input with ", - input_rank, " dimensions."); - } - - int32 wrapped_index = reduction_index; - if (wrapped_index < 0) { - wrapped_index += input_rank; - } - - true_indices.insert(wrapped_index); + std::set true_indices; + if (reduction_indices_t->dtype() == DataType::DT_INT32) { + TF_RETURN_IF_ERROR(ReductionShapeHelper(reduction_indices_t, + input_rank, true_indices)); + } else if (reduction_indices_t->dtype() == DataType::DT_INT64) { + TF_RETURN_IF_ERROR(ReductionShapeHelper(reduction_indices_t, + input_rank, true_indices)); + } else { + return errors::InvalidArgument( + "reduction_indices can only be int32 or int64"); } std::vector dims; @@ -1319,11 +1336,10 @@ Status ScatterNdUpdateShape(InferenceContext* c) { Status s = c->Merge(prefix_indices, prefix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( - "The outer ", num_outer_dims, - " dimensions of indices.shape=", c->DebugString(indices_shape), - " must match the outer ", num_outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); + "The outer ", num_outer_dims, " dimensions of indices.shape=", + c->DebugString(indices_shape), " must match the outer ", + num_outer_dims, " dimensions of updates.shape=", + c->DebugString(updates_shape), ": ", s.error_message()); } ShapeHandle input_suffix;