diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 4796c3c00a4..315c99d32bf 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1020,6 +1020,31 @@ Status UnknownShape(shape_inference::InferenceContext* c) {
   return Status::OK();
 }
 
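+// Collects the dimensions named in 'reduction_indices_t' (element type T)
+// into '*true_indices', wrapping negatives and rejecting out-of-range axes.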
+template <typename T>
+Status ReductionShapeHelper(const Tensor* reduction_indices_t,
+                            const int32 input_rank,
+                            std::set<int64>* true_indices) {
+  auto reduction_indices = reduction_indices_t->flat<T>();
+  for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
+    const T reduction_index = reduction_indices(i);
+    if (reduction_index < -input_rank || reduction_index >= input_rank) {
+      return errors::InvalidArgument("Invalid reduction dimension ",
+                                     reduction_index, " for input with ",
+                                     input_rank, " dimensions.");
+    }
+
+    auto wrapped_index = reduction_index;
+    if (wrapped_index < 0) {
+      wrapped_index += input_rank;
+    }
+
+    true_indices->insert(wrapped_index);
+  }
+  return Status::OK();
+}
+
 Status ReductionShape(InferenceContext* c) {
   ShapeHandle input = c->input(0);
 
@@ -1050,22 +1075,16 @@ Status ReductionShape(InferenceContext* c) {
   }
 
   const int32 input_rank = c->Rank(input);
-  std::set<int32> true_indices;
-  auto reduction_indices = reduction_indices_t->flat<int32>();
-  for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
-    int32 reduction_index = reduction_indices(i);
-    if (reduction_index < -input_rank || reduction_index >= input_rank) {
-      return errors::InvalidArgument("Invalid reduction dimension ",
-                                     reduction_index, " for input with ",
-                                     input_rank, " dimensions.");
-    }
-
-    int32 wrapped_index = reduction_index;
-    if (wrapped_index < 0) {
-      wrapped_index += input_rank;
-    }
-
-    true_indices.insert(wrapped_index);
+  std::set<int64> true_indices;
+  if (reduction_indices_t->dtype() == DataType::DT_INT32) {
+    TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
+                                                   input_rank, &true_indices));
+  } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
+    TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
+                                                   input_rank, &true_indices));
+  } else {
+    return errors::InvalidArgument(
+        "reduction_indices can only be int32 or int64");
   }
 
   std::vector<DimensionHandle> dims;
@@ -1319,11 +1338,10 @@ Status ScatterNdUpdateShape(InferenceContext* c) {
       Status s = c->Merge(prefix_indices, prefix_updates, &unused);
       if (!s.ok()) {
         return errors::InvalidArgument(
-            "The outer ", num_outer_dims,
-            " dimensions of indices.shape=", c->DebugString(indices_shape),
-            " must match the outer ", num_outer_dims,
-            " dimensions of updates.shape=", c->DebugString(updates_shape),
-            ": ", s.error_message());
+            "The outer ", num_outer_dims, " dimensions of indices.shape=",
+            c->DebugString(indices_shape), " must match the outer ",
+            num_outer_dims, " dimensions of updates.shape=",
+            c->DebugString(updates_shape), ": ", s.error_message());
       }
 
       ShapeHandle input_suffix;
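
For reviewers, a minimal standalone sketch of the pattern the ReductionShape hunks introduce: a helper templated on the index element type, dispatched on the runtime element type, that wraps negative reduction axes and rejects out-of-range ones. This is not TensorFlow code; IndexVector, WrapAndValidate, and the use of std::variant in place of the DataType check are illustrative stand-ins, and standard fixed-width types replace the TF int32/int64 aliases.

#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <variant>
#include <vector>

// Illustrative stand-in for a reduction_indices tensor whose element type is
// only known at runtime (int32 or int64).
using IndexVector = std::variant<std::vector<int32_t>, std::vector<int64_t>>;

// Mirrors the shape of ReductionShapeHelper<T>: wrap negative axes into
// [0, input_rank) and reject anything outside [-input_rank, input_rank).
template <typename T>
bool WrapAndValidate(const std::vector<T>& indices, int32_t input_rank,
                     std::set<int64_t>* true_indices, std::string* error) {
  for (const T index : indices) {
    if (index < -input_rank || index >= input_rank) {
      *error = "Invalid reduction dimension " + std::to_string(index) +
               " for input with " + std::to_string(input_rank) + " dimensions.";
      return false;
    }
    true_indices->insert(index < 0 ? index + input_rank : index);
  }
  return true;
}

int main() {
  const int32_t input_rank = 4;
  // Runtime dispatch on the element type, analogous to the DT_INT32/DT_INT64
  // branch added to ReductionShape.
  const IndexVector reduction_indices = std::vector<int64_t>{-1, 0, 2};

  std::set<int64_t> true_indices;
  std::string error;
  const bool ok = std::visit(
      [&](const auto& v) {
        return WrapAndValidate(v, input_rank, &true_indices, &error);
      },
      reduction_indices);

  if (!ok) {
    std::cerr << error << "\n";
    return 1;
  }
  for (const int64_t d : true_indices) std::cout << d << " ";  // prints: 0 2 3
  std::cout << "\n";
  return 0;
}

Keying the accumulating set on a 64-bit integer, as the patch does, lets the int32 and int64 branches share a single container and keeps the downstream dimension-building loop independent of the index dtype.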