diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 36ae36e7b74..8157f4ee01d 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -2257,66 +2257,57 @@ Status GatherNdShape(InferenceContext* c) {
   return Status::OK();
 }
 
-Status ScatterNdUpdateShape(InferenceContext* c) {
-  ShapeHandle input_shape = c->input(0);
-  if (c->input_handle_shapes_and_types(0) != nullptr) {
-    // This is called for tf.scatter_nd_update; input is a Variable handle.
-    const auto& shape_and_type = *(c->input_handle_shapes_and_types(0));
-    if (shape_and_type.size() == 1) {
-      input_shape = shape_and_type[0].shape;
-    }
-  }
-  ShapeHandle indices_shape;
-  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
-  ShapeHandle updates_shape;
-  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
-
+Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
+                            ShapeHandle updates_shape,
+                            ShapeHandle input_shape) {
   if (c->Value(c->NumElements(input_shape)) == 0 &&
       (c->Value(c->NumElements(indices_shape)) > 0 ||
        c->Value(c->NumElements(updates_shape)) > 0)) {
     return errors::InvalidArgument(
-        "Indices and updates specified for empty output shape");
+        "Indices and updates specified for empty input");
   }
 
   if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
-    const int64 num_outer_dims = c->Rank(indices_shape) - 1;
-    const DimensionHandle index_size = c->Dim(indices_shape, -1);
+    const int64 outer_dims = c->Rank(indices_shape) - 1;
+    const DimensionHandle ixdim = c->Dim(indices_shape, -1);
 
     // We can only do more validation if the last dimension of indices
     // is a known value.
-    if (c->ValueKnown(index_size)) {
-      const int64 ix = c->Value(index_size);
+    if (c->ValueKnown(ixdim)) {
+      int64 ix = c->Value(ixdim);
       ShapeHandle unused;
       ShapeHandle prefix_indices;
       TF_RETURN_IF_ERROR(
-          c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices));
+          c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
       ShapeHandle prefix_updates;
       TF_RETURN_IF_ERROR(
-          c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates));
+          c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
 
       Status s = c->Merge(prefix_indices, prefix_updates, &unused);
       if (!s.ok()) {
         return errors::InvalidArgument(
-            "The outer ", num_outer_dims,
-            " dimensions of indices.shape=", c->DebugString(indices_shape),
-            " must match the outer ", num_outer_dims,
-            " dimensions of updates.shape=", c->DebugString(updates_shape),
-            ": ", s.error_message());
+            "Dimensions [0,", outer_dims,
+            ") of indices[shape=", c->DebugString(indices_shape),
+            "] = ", c->DebugString(prefix_indices),
+            " must match dimensions [0,", outer_dims,
+            ") of updates[shape=", c->DebugString(updates_shape),
+            "] = ", c->DebugString(prefix_updates), ": ", s.error_message());
       }
 
-      ShapeHandle input_suffix;
-      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix));
+      ShapeHandle suffix_output;
+      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output));
       ShapeHandle suffix_updates;
       TF_RETURN_IF_ERROR(
-          c->Subshape(updates_shape, num_outer_dims, &suffix_updates));
-      s = c->Merge(input_suffix, suffix_updates, &unused);
+          c->Subshape(updates_shape, outer_dims, &suffix_updates));
+      s = c->Merge(suffix_output, suffix_updates, &unused);
       if (!s.ok()) {
         return errors::InvalidArgument(
-            "The inner ", c->Rank(input_shape) - ix,
-            " dimensions of input.shape=", c->DebugString(input_shape),
-            " must match the inner ", c->Rank(updates_shape) - num_outer_dims,
-            " dimensions of updates.shape=", c->DebugString(updates_shape),
-            ": ", s.error_message());
+            "Dimensions [", ix, ",", c->Rank(input_shape),
+            ") of input[shape=", c->DebugString(input_shape),
+            "] = ", c->DebugString(suffix_output), " must match dimensions [",
+            outer_dims, ",", c->Rank(updates_shape),
+            ") of updates[shape=", c->DebugString(updates_shape),
+            "] = ", c->DebugString(suffix_updates), ": ", s.error_message());
       }
     }
   }
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 218400c2435..f3e02638f54 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -241,8 +241,9 @@ Status ValidateVariableResourceHandle(
 // Shape function for GatherNd operations.
 Status GatherNdShape(InferenceContext* c);
 
-// Shape function for ScatterNd update/add/sub/... operations.
-Status ScatterNdUpdateShape(InferenceContext* c);
+// Helper shape function for ScatterNd.../TensorScatter... operations.
+Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
+                            ShapeHandle updates_shape, ShapeHandle input_shape);
 
 // Shape function for ops with an explicit "shape" attribute.
 Status ExplicitShape(InferenceContext* c);
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 88bf16d974e..942740b9af3 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -100,29 +100,31 @@ class ScatterNdOp : public OpKernel {
     const int64 outer_dims = indices.shape().dims() - 1;
 
     for (int i = 0; i < outer_dims; ++i) {
-      OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
-                  errors::InvalidArgument(
-                      "Outer dimensions of indices and update must match. "
-                      "Indices shape: ",
-                      indices.shape().DebugString(),
-                      ", updates shape:", updates.shape().DebugString()));
+      OP_REQUIRES(
+          c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
+          errors::InvalidArgument(
+              "Dimensions [0,", outer_dims,
+              ") of indices[shape=", indices.shape().DebugString(),
+              "] must match dimensions [0,", outer_dims,
+              ") of updates[shape=", updates.shape().DebugString(), "]"));
     }
 
     const int64 ix = indices.shape().dim_size(outer_dims);
-    OP_REQUIRES(
-        c, updates.shape().dims() - outer_dims == shape.dims() - ix,
-        errors::InvalidArgument("Inner dimensions of output shape must match "
-                                "inner dimensions of updates shape. Output: ",
-                                shape.DebugString(),
-                                " updates: ", updates.shape().DebugString()));
+    OP_REQUIRES(c, updates.shape().dims() - outer_dims == shape.dims() - ix,
+                errors::InvalidArgument(
+                    "Dimensions [", ix, ",", shape.dims(), ") of input[shape=",
+                    shape.DebugString(), "] must match dimensions [",
+                    outer_dims, ",", updates.shape().dims(),
+                    ") of updates[shape=", updates.shape().DebugString(), "]"));
+
     for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
       OP_REQUIRES(
           c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
-          errors::InvalidArgument(
-              "The inner ", shape.dims() - ix,
-              " dimensions of output.shape=", shape.DebugString(),
-              " must match the inner ", updates.shape().dims() - outer_dims,
-              " dimensions of updates.shape=", updates.shape().DebugString()));
+          errors::InvalidArgument("Dimensions [", ix, ",", shape.dims(),
+                                  ") of input[shape=", shape.DebugString(),
+                                  "] must match dimensions [", outer_dims, ",",
+                                  updates.shape().dims(), ") of updates[shape=",
+                                  updates.shape().DebugString(), "]"));
     }
     OP_REQUIRES(c, shape_input.dims() == 1,
                 errors::InvalidArgument("Shape must be a vector"));
@@ -602,30 +604,35 @@ Status ValidateUpdateShape(const TensorShape& params_shape,
       (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1;
   const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1;
 
-  auto shape_err = [&]() {
+  auto shape_err_prefix = [&]() {
     return errors::InvalidArgument(
-        "Must have updates.shape = indices.shape[:batch_dim] + ",
-        "params_shape[slice_dim:], got updates.shape: ",
-        updates.shape().DebugString(),
-        ", indices.shape: ", indices.shape().DebugString(),
-        ", params_shape: ", params_shape.DebugString(),
-        ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim);
+        "Dimensions [0,", batch_dim,
+        ") of indices[shape=", indices.shape().DebugString(),
+        "] must match dimensions [0,", batch_dim,
+        ") of updates[shape=", updates.shape().DebugString(), "]");
+  };
+  auto shape_err_suffix = [&]() {
+    return errors::InvalidArgument(
+        "Dimensions [", slice_dim, ",", params_shape.dims(),
+        ") of input[shape=", params_shape.DebugString(),
+        "] must match dimensions [", slice_dim, ",", updates.dims(),
+        ") of updates[shape=", updates.shape().DebugString(), "]");
   };
 
-  if (updates.dims() < batch_dim) return shape_err();
+  if (updates.dims() < batch_dim) return shape_err_prefix();
   if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
-    return shape_err();
+    return shape_err_suffix();
   }
   if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
-    return shape_err();
+    return shape_err_suffix();
   }
   for (int d = 0; d < batch_dim; ++d) {
-    if (updates.dim_size(d) != indices.dim_size(d)) return shape_err();
+    if (updates.dim_size(d) != indices.dim_size(d)) return shape_err_prefix();
   }
   for (int d = 0; d < updates.dims() - batch_dim; ++d) {
     if (updates.dim_size(d + batch_dim) !=
         params_shape.dim_size(d + slice_dim)) {
-      return shape_err();
+      return shape_err_suffix();
     }
   }
   return Status::OK();
@@ -654,9 +661,9 @@ Status PrepareAndValidateInputs(const TensorShape& params_shape,
 
   if (updates.dim_size(0) != indices.dim_size(0)) {
     return errors::InvalidArgument(
-        "The outermost dimension of updates and indices ",
-        "must match. Got indices.shape ", indices_shape.DebugString(),
-        ", updates.shape ", updates_shape.DebugString());
+        "Dimensions [0,1) of indices[shape=", indices_shape.DebugString(),
+        "] = ", indices.dim_size(0), " must match dimensions [0,1) of updates[",
+        "shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0));
   }
   TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates));
 
diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc
index 1461831a1fb..9c31bed784f 100644
--- a/tensorflow/core/kernels/scatter_nd_op_test.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_test.cc
@@ -200,8 +200,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) {
   Status s = RunOpKernel();
   EXPECT_TRUE(absl::StrContains(
       s.ToString(),
-      "The outermost dimension of updates and indices must match. Got "
-      "indices.shape [1,3,1], updates.shape [3,3]"))
+      "Dimensions [0,1) of indices[shape=[1,3,1]] = 1 must match dimensions "
+      "[0,1) of updates[shape=[3,3]] = 3"))
       << s;
 }
 
@@ -217,7 +217,9 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
       {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
   Status s = RunOpKernel();
   EXPECT_TRUE(absl::StrContains(
-      s.ToString(), "Must have updates.shape = indices.shape[:batch_dim]"))
+      s.ToString(),
+      "Dimensions [1,2) of input[shape=[5,3]] must match dimensions [1,2) of "
+      "updates[shape=[3,4]]"))
       << s;
 }
 
@@ -233,7 +235,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
   Status s = RunOpKernel();
   EXPECT_TRUE(absl::StrContains(
       s.ToString(),
-      "The outermost dimension of updates and indices must match."))
+      "Dimensions [0,1) of indices[shape=[3,1]] = 3 must match dimensions [0,1)"
+      " of updates[shape=[2,3]] = 2"))
       << s;
 }
 
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 11bfb9a3346..b4dfe6187d5 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2974,73 +2974,6 @@ REGISTER_OP("QuantizedInstanceNorm")
 
 namespace {
 
-Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
-                            ShapeHandle updates_shape,
-                            ShapeHandle output_shape) {
-  if (c->Value(c->NumElements(output_shape)) == 0 &&
-      (c->Value(c->NumElements(indices_shape)) > 0 ||
-       c->Value(c->NumElements(updates_shape)) > 0)) {
-    return errors::InvalidArgument(
-        "Indices and updates specified for empty output shape");
-  }
-
-  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
-    const int64 outer_dims = c->Rank(indices_shape) - 1;
-    const DimensionHandle ixdim = c->Dim(indices_shape, -1);
-
-    // We can only do more validation if the last dimension of indices
-    // is a known value.
-    if (c->ValueKnown(ixdim)) {
-      int64 ix = c->Value(ixdim);
-      ShapeHandle unused;
-      ShapeHandle prefix_indices;
-      TF_RETURN_IF_ERROR(
-          c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
-      ShapeHandle prefix_updates;
-      TF_RETURN_IF_ERROR(
-          c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
-
-      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
-      if (!s.ok()) {
-        return errors::InvalidArgument(
-            "The outer ", outer_dims,
-            " dimensions of indices.shape=", c->DebugString(indices_shape),
-            " must match the outer ", outer_dims,
-            " dimensions of updates.shape=", c->DebugString(updates_shape),
-            ": ", s.error_message());
-      }
-
-      ShapeHandle suffix_output;
-      TF_RETURN_IF_ERROR(c->Subshape(output_shape, ix, &suffix_output));
-      ShapeHandle suffix_updates;
-      TF_RETURN_IF_ERROR(
-          c->Subshape(updates_shape, outer_dims, &suffix_updates));
-      s = c->Merge(suffix_output, suffix_updates, &unused);
-      if (!s.ok()) {
-        return errors::InvalidArgument(
-            "The inner ", c->Rank(output_shape) - ix,
-            " dimensions of output.shape=", c->DebugString(output_shape),
-            " must match the inner ", c->Rank(updates_shape) - outer_dims,
-            " dimensions of updates.shape=", c->DebugString(updates_shape),
-            ": ", s.error_message());
-      }
-    }
-  }
-
-  c->set_output(0, output_shape);
-  return Status::OK();
-}
-
-Status ScatterNdShape(InferenceContext* c) {
-  ShapeHandle indices_shape;
-  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
-  ShapeHandle updates_shape;
-  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
-  ShapeHandle output_shape;
-  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
-  return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
-}
-
 Status ScatterNdTensorShape(InferenceContext* c) {
   ShapeHandle output_shape;
   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape));
@@ -3048,7 +2981,8 @@ Status ScatterNdTensorShape(InferenceContext* c) {
   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
   ShapeHandle updates_shape;
   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
-  return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
+  return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape,
+                                               output_shape);
 }
 
 }  // namespace
@@ -3088,7 +3022,16 @@ REGISTER_OP("ScatterNd")
     .Output("output: T")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
-    .SetShapeFn(ScatterNdShape);
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle indices_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
+      ShapeHandle updates_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
+      ShapeHandle output_shape;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
+      return shape_inference::ScatterNdShapeHelper(c, indices_shape,
+                                                   updates_shape, output_shape);
+    });
 
 REGISTER_OP("TensorScatterUpdate")
     .Input("tensor: T")
@@ -3142,7 +3085,7 @@ REGISTER_OP("ScatterNdNonAliasingAdd")
     .Output("output: T")
     .Attr("T: {numbertype, bool}")
     .Attr("Tindices: {int32, int64}")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdTensorShape);
 
 REGISTER_OP("FakeQuantWithMinMaxArgs")
     .Attr("min: float = -6.0")
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 1725bdbac39..412c926d386 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -27,6 +27,39 @@ limitations under the License.
 
 namespace tensorflow {
 
+TEST(ArrayOpsTest, TensorScatterUpdate_ShapeFn) {
+  ShapeInferenceTestOp op("TensorScatterUpdate");
+
+  INFER_OK(op, "[4,3];[8,2];[8]", "in0");
+  INFER_OK(op, "[?,?];[?,2];[?]", "in0");
+  INFER_OK(op, "[?];[?];[?]", "in0");
+
+  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op,
+              "[];[?,2];[?]");
+  INFER_ERROR("Indices and updates specified for empty input", op,
+              "[0,2,2];[8,2];[8]");
+  INFER_ERROR(
+      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
+      "dimensions [0,1) of updates[shape=[9]] = [9]",
+      op, "[?,?];[8,2];[9]");
+  INFER_ERROR(
+      "Dimensions [2,2) of input[shape=[?,?]] = [] must match "
+      "dimensions [1,2) of updates[shape=[?,1]] = [1]",
+      op, "[?,?];[?,2];[?,1]");
+}
+
+TEST(ArrayOpsTest, ScatterNd_ShapeFn) {
+  ShapeInferenceTestOp op("ScatterNd");
+
+  INFER_OK(op, "[8,2];[8];[2]", "[?,?]");
+
+  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[?,2];[?];[]");
+  INFER_ERROR(
+      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
+      "dimensions [0,1) of updates[shape=[9]] = [9]",
+      op, "[8,2];[9];[?]");
+}
+
 TEST(ArrayOpsTest, UnravelIndex_ShapeFn) {
   ShapeInferenceTestOp op("UnravelIndex");
 
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index 500d5ec88b8..5d856396360 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -131,6 +131,22 @@ Status ScatterUpdateShape(InferenceContext* c) {
   return Status::OK();
 }
 
+Status ScatterNdUpdateShape(InferenceContext* c) {
+  ShapeHandle input_shape = c->input(0);
+  if (c->input_handle_shapes_and_types(0) != nullptr) {
+    const auto& shape_and_type = *(c->input_handle_shapes_and_types(0));
+    if (!shape_and_type.empty()) {
+      input_shape = shape_and_type[0].shape;
+    }
+  }
+  ShapeHandle indices_shape;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
+  ShapeHandle updates_shape;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
+  return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape,
+                                               input_shape);
+}
+
 }  // namespace
 
 REGISTER_OP("ScatterUpdate")
@@ -211,7 +227,7 @@ REGISTER_OP("ScatterNdUpdate")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdUpdate")
     .Input("ref: resource")
@@ -220,7 +236,7 @@ REGISTER_OP("ResourceScatterNdUpdate")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdAdd")
     .Input("ref: resource")
@@ -229,7 +245,7 @@ REGISTER_OP("ResourceScatterNdAdd")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdSub")
     .Input("ref: resource")
@@ -238,7 +254,7 @@ REGISTER_OP("ResourceScatterNdSub")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdMin")
     .Input("ref: resource")
@@ -247,7 +263,7 @@ REGISTER_OP("ResourceScatterNdMin")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ResourceScatterNdMax")
     .Input("ref: resource")
@@ -256,7 +272,7 @@ REGISTER_OP("ResourceScatterNdMax")
     .Attr("T: type")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = true")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ScatterNdAdd")
     .Input("ref: Ref(T)")
@@ -266,7 +282,7 @@ REGISTER_OP("ScatterNdAdd")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ScatterNdSub")
     .Input("ref: Ref(T)")
@@ -276,7 +292,7 @@ REGISTER_OP("ScatterNdSub")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ScatterNdMax")
     .Input("ref: Ref(T)")
@@ -286,7 +302,7 @@ REGISTER_OP("ScatterNdMax")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("ScatterNdMin")
     .Input("ref: Ref(T)")
@@ -296,7 +312,7 @@ REGISTER_OP("ScatterNdMin")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+    .SetShapeFn(ScatterNdUpdateShape);
 
 REGISTER_OP("CountUpTo")
     .Input("ref: Ref(T)")
diff --git a/tensorflow/core/ops/state_ops_test.cc b/tensorflow/core/ops/state_ops_test.cc
index a0caad4a49f..bc68cf46f03 100644
--- a/tensorflow/core/ops/state_ops_test.cc
+++ b/tensorflow/core/ops/state_ops_test.cc
@@ -69,6 +69,28 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) {
   INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[2];[];[2]");
 }
 
+TEST(StateOpsTest, ResourceScatterNdUpdate_ShapeFn) {
+  ShapeInferenceTestOp op("ResourceScatterNdUpdate");
+  TF_ASSERT_OK(NodeDefBuilder("test", "ResourceScatterNdUpdate")
+                   .Input("ref", 0, DT_RESOURCE)
+                   .Input("indices", 0, DT_INT32)
+                   .Input("updates", 1, DT_FLOAT)
+                   .Finalize(&op.node_def));
+
+  std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types;
+  op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types);
+  op.input_resource_handle_shapes_and_types.push_back(nullptr);
+  op.input_resource_handle_shapes_and_types.push_back(nullptr);
+  shapes_and_types.emplace_back("[?,?]", DT_FLOAT);
+  INFER_OK(op, "[?];[?,2];[?]", "");
+  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op,
+              "[?];[?,2];[]");
+  INFER_ERROR(
+      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
+      "dimensions [0,1) of updates[shape=[9]] = [9]",
+      op, "[?];[8,2];[9]");
+}
+
 TEST(StateOpsTest, TemporaryVariable_ShapeFn) {
   ShapeInferenceTestOp op("TemporaryVariable");
   TensorShape shape({1, 2, 3});
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index c5e5e549ee7..d5843c1a766 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -331,24 +331,24 @@ class StatefulScatterNdTest(test.TestCase):
       self.evaluate(ref.initializer)
       self.assertAllEqual(expected_result, self.evaluate(scatter_update))
 
-  @test_util.run_deprecated_v1
   def testRank3InvalidShape1(self):
     indices = array_ops.zeros([3, 2, 2], dtypes.int32)
     updates = array_ops.zeros([2, 2, 2], dtypes.int32)
     shape = np.array([2, 2, 2])
     ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
     with self.assertRaisesWithPredicateMatch(
-        ValueError, r"The outer \d+ dimensions of indices\.shape="):
+        (errors.InvalidArgumentError, ValueError),
+        r"Dimensions \[\d,\d\) of indices\[shape="):
       state_ops.scatter_nd_update(ref, indices, updates)
 
-  @test_util.run_deprecated_v1
   def testRank3InvalidShape2(self):
     indices = array_ops.zeros([2, 2, 1], dtypes.int32)
     updates = array_ops.zeros([2, 2], dtypes.int32)
     shape = np.array([2, 2, 2])
     ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
     with self.assertRaisesWithPredicateMatch(
-        ValueError, r"The inner \d+ dimensions of input\.shape="):
+        (errors.InvalidArgumentError, ValueError),
+        r"Dimensions \[\d,\d\) of input\[shape="):
       state_ops.scatter_nd_update(ref, indices, updates)
 
   def testConcurrentUpdates(self):
@@ -511,14 +511,14 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase):
       shape = array_ops.placeholder(dtypes.int32, shape=[None])
       self.scatter_nd(indices, updates, shape)
 
-  @test_util.run_deprecated_v1
   def testEmptyOutputShape1(self):
     indices = array_ops.zeros([2, 2, 2], dtypes.int32)
     updates = array_ops.zeros([2, 2, 2], dtypes.int32)
     shape = constant_op.constant([0, 3, 2], dtypes.int32)
 
     with self.assertRaisesWithPredicateMatch(
-        ValueError, "Indices and updates specified for empty output shape"):
+        (errors.InvalidArgumentError, ValueError),
+        "Indices and updates specified for empty"):
       self.scatter_nd(indices, updates, shape)
 
   def testEmptyOutputShape2(self):
@@ -529,7 +529,7 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase):
 
       with self.cached_session():
         with self.assertRaisesOpError(
-            "Indices and updates specified for empty output"):
+            "Indices and updates specified for empty (input|output)"):
           self.scatter_nd(indices, updates, shape).eval(
               feed_dict={
                   indices: np.zeros([2, 2, 2], dtype=np.int32),
@@ -545,22 +545,22 @@ class ScatterNdTest(test.TestCase, parameterized.TestCase):
     with self.cached_session():
       self.assertEqual(self.evaluate(scatter).size, 0)
 
-  @test_util.run_deprecated_v1
   def testRank3InvalidShape1(self):
     indices = array_ops.zeros([3, 2, 2], dtypes.int32)
     updates = array_ops.zeros([2, 2, 2], dtypes.int32)
     shape = np.array([2, 2, 2])
     with self.assertRaisesWithPredicateMatch(
-        ValueError, r"The outer \d+ dimensions of indices\.shape="):
+        (errors.InvalidArgumentError, ValueError),
+        r"Dimensions \[\d\,\d\) of indices\[shape="):
       self.scatter_nd(indices, updates, shape)
 
-  @test_util.run_deprecated_v1
   def testRank3InvalidShape2(self):
     indices = array_ops.zeros([2, 2, 1], dtypes.int32)
     updates = array_ops.zeros([2, 2], dtypes.int32)
     shape = np.array([2, 2, 2])
     with self.assertRaisesWithPredicateMatch(
-        ValueError, r"The inner \d+ dimensions of (input|output)\.shape="):
+        (errors.InvalidArgumentError, ValueError),
+        r"Dimensions \[\d\,\d\) of input\[shape="):
       self.scatter_nd(indices, updates, shape)
 
   @parameterized.parameters(set((True, context.executing_eagerly())))