diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 4796c3c00a4..315c99d32bf 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1020,6 +1020,29 @@ Status UnknownShape(shape_inference::InferenceContext* c) { return Status::OK(); } +template <typename T> +Status ReductionShapeHelper(const Tensor* reduction_indices_t, + const int32 input_rank, + std::set<int64>& true_indices) { + auto reduction_indices = reduction_indices_t->flat<T>(); + for (int i = 0; i < reduction_indices_t->NumElements(); ++i) { + const T reduction_index = reduction_indices(i); + if (reduction_index < -input_rank || reduction_index >= input_rank) { + return errors::InvalidArgument("Invalid reduction dimension ", + reduction_index, " for input with ", + input_rank, " dimensions."); + } + + auto wrapped_index = reduction_index; + if (wrapped_index < 0) { + wrapped_index += input_rank; + } + + true_indices.insert(wrapped_index); + } + return Status::OK(); +} + Status ReductionShape(InferenceContext* c) { ShapeHandle input = c->input(0); @@ -1050,22 +1073,16 @@ Status ReductionShape(InferenceContext* c) { } const int32 input_rank = c->Rank(input); - std::set<int32> true_indices; - auto reduction_indices = reduction_indices_t->flat<int32>(); - for (int i = 0; i < reduction_indices_t->NumElements(); ++i) { - int32 reduction_index = reduction_indices(i); - if (reduction_index < -input_rank || reduction_index >= input_rank) { - return errors::InvalidArgument("Invalid reduction dimension ", - reduction_index, " for input with ", - input_rank, " dimensions."); - } - - int32 wrapped_index = reduction_index; - if (wrapped_index < 0) { - wrapped_index += input_rank; - } - - true_indices.insert(wrapped_index); + std::set<int64> true_indices; + if (reduction_indices_t->dtype() == DataType::DT_INT32) { + TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t, + input_rank, true_indices)); + } else if (reduction_indices_t->dtype() == DataType::DT_INT64) { + TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t, + input_rank, true_indices)); + } else { + return errors::InvalidArgument( + "reduction_indices can only be int32 or int64"); } std::vector<DimensionHandle> dims; @@ -1319,11 +1336,10 @@ Status ScatterNdUpdateShape(InferenceContext* c) { Status s = c->Merge(prefix_indices, prefix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( - "The outer ", num_outer_dims, - " dimensions of indices.shape=", c->DebugString(indices_shape), - " must match the outer ", num_outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); + "The outer ", num_outer_dims, " dimensions of indices.shape=", + c->DebugString(indices_shape), " must match the outer ", + num_outer_dims, " dimensions of updates.shape=", + c->DebugString(updates_shape), ": ", s.error_message()); } ShapeHandle input_suffix;