diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index dc11f7caa2c..c1376c6a3d9 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -160,15 +160,24 @@ Status MakeErrorStatus(complex128 lhs, complex128 rhs,
 
 // A recursive function which iterates through every index of expected and
 // actual literal and compares their values elementwise. Returns true if all
-// elements are equal.
+// elements are equal. Mismatched must either be:
+//    - a literal of booleans that has the same shape as expected and actual. In
+//      this case, each index in mismatched will be set to true if expected does
+//      not equal actual at that index and false if there are equal.
+//    - nullptr. In this case, the function will return once any mismatch is
+//      found between expected and actual.
 template <typename NativeT>
 Status Equal(LiteralSlice expected, LiteralSlice actual,
-             absl::Span<int64> multi_index, int64 dimension) {
+             absl::Span<int64> multi_index, int64 dimension,
+             Literal* mismatched = nullptr) {
   if (dimension == expected.shape().dimensions_size()) {
     NativeT expected_value = expected.Get<NativeT>(multi_index);
     NativeT actual_value = actual.Get<NativeT>(multi_index);
     bool result =
         CompareEqual<NativeT>(expected_value, actual_value, multi_index);
+    if (mismatched) {
+      mismatched->Set<bool>(multi_index, !result);
+    }
     return result ? Status::OK()
                   : MakeErrorStatus<NativeT>(expected_value, actual_value,
                                              multi_index);
@@ -177,8 +186,13 @@ Status Equal(LiteralSlice expected, LiteralSlice actual,
   Status result;
   for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
     multi_index[dimension] = i;
-    TF_RETURN_IF_ERROR(
-        Equal<NativeT>(expected, actual, multi_index, dimension + 1));
+    if (mismatched != nullptr) {
+      result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1,
+                                   mismatched));
+    } else {
+      TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index,
+                                        dimension + 1, mismatched));
+    }
   }
   return result;
 }
@@ -264,10 +278,11 @@ class NearComparator {
   // within the given error bound. In case of error, the status contains a
   // detailed message about the discrepancy.
   static Status Compare(const LiteralSlice& expected,
-                        const LiteralSlice& actual, ErrorSpec error,
+                        const LiteralSlice& actual,
+                        const ShapeIndex& shape_index, ErrorSpec error,
                         bool detailed_message,
                         const MiscompareCallback& miscompare_callback) {
-    NearComparator<NativeT> comparator(expected, actual, error,
+    NearComparator<NativeT> comparator(expected, actual, shape_index, error,
                                        detailed_message, miscompare_callback);
     return comparator.Run();
   }
@@ -300,10 +315,12 @@ class NearComparator {
   };
 
   NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
-                 ErrorSpec error, bool detailed_message,
+                 const ShapeIndex& shape_index, ErrorSpec error,
+                 bool detailed_message,
                  const MiscompareCallback& miscompare_callback)
       : expected_(expected),
         actual_(actual),
+        shape_index_(shape_index),
         error_(error),
         detailed_message_(detailed_message),
         miscompare_callback_(miscompare_callback),
@@ -329,7 +346,7 @@ class NearComparator {
     if (num_mismatches_ == 0) {
       return Status::OK();
     } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
-      miscompare_callback_(expected_, actual_, mismatches_);
+      miscompare_callback_(expected_, actual_, mismatches_, shape_index_);
     }
     return InvalidArgument("%s", ErrorMessage());
   }
@@ -595,6 +612,9 @@ class NearComparator {
   LiteralSlice expected_;
   LiteralSlice actual_;
 
+  // The shape index of the LiteralSlice that is being compared.
+  ShapeIndex shape_index_;
+
   // The error bounds of the comparison.
   ErrorSpec error_;
 
@@ -653,70 +673,94 @@ constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
 template <typename NativeT>
 constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
 
-Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
+Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual,
+                   const ShapeIndex& shape_index,
+                   const MiscompareCallback& miscompare_callback) {
   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
-  std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
-  auto index = absl::MakeSpan(multi_index);
+
   Status result;
-  switch (expected.shape().element_type()) {
-    case PRED:
-      result = Equal<bool>(expected, actual, index, 0);
-      break;
-    case S8:
-      result = Equal<int8>(expected, actual, index, 0);
-      break;
-    case S16:
-      result = Equal<int16>(expected, actual, index, 0);
-      break;
-    case S32:
-      result = Equal<int32>(expected, actual, index, 0);
-      break;
-    case S64:
-      result = Equal<int64>(expected, actual, index, 0);
-      break;
-    case U8:
-      result = Equal<uint8>(expected, actual, index, 0);
-      break;
-    case U16:
-      result = Equal<uint16>(expected, actual, index, 0);
-      break;
-    case U32:
-      result = Equal<uint32>(expected, actual, index, 0);
-      break;
-    case U64:
-      result = Equal<uint64>(expected, actual, index, 0);
-      break;
-    case BF16:
-      result = Equal<bfloat16>(expected, actual, index, 0);
-      break;
-    case F16:
-      result = Equal<half>(expected, actual, index, 0);
-      break;
-    case F32:
-      result = Equal<float>(expected, actual, index, 0);
-      break;
-    case F64:
-      result = Equal<double>(expected, actual, index, 0);
-      break;
-    case C64:
-      result = Equal<complex64>(expected, actual, index, 0);
-      break;
-    case C128:
-      result = Equal<complex128>(expected, actual, index, 0);
-      break;
-    case TUPLE: {
-      for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
-        result.Update(EqualHelper(LiteralSlice(expected, {i}),
-                                  LiteralSlice(actual, {i})));
+  if (expected.shape().IsTuple()) {
+    ShapeIndex next_index = shape_index;
+    for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+      next_index.push_back(i);
+      Status tuple_result =
+          EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}),
+                      next_index, miscompare_callback);
+      if (miscompare_callback) {
+        result.Update(tuple_result);
+      } else {
+        TF_RETURN_IF_ERROR(tuple_result);
       }
-      break;
+      next_index.pop_back();
+    }
+  } else {
+    std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+    auto index = absl::MakeSpan(multi_index);
+
+    Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED,
+                                               expected.shape().dimensions());
+    Literal miscompared(unequal_shape);
+    Literal* miscompared_ptr =
+        (miscompare_callback == nullptr ? nullptr : &miscompared);
+
+    switch (expected.shape().element_type()) {
+      case PRED:
+        result = Equal<bool>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case S8:
+        result = Equal<int8>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case S16:
+        result = Equal<int16>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case S32:
+        result = Equal<int32>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case S64:
+        result = Equal<int64>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case U8:
+        result = Equal<uint8>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case U16:
+        result = Equal<uint16>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case U32:
+        result = Equal<uint32>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case U64:
+        result = Equal<uint64>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case BF16:
+        result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case F16:
+        result = Equal<half>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case F32:
+        result = Equal<float>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case F64:
+        result = Equal<double>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case C64:
+        result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case C128:
+        result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr);
+        break;
+      case TOKEN:
+        // Tokens have no on-device representation and are trivially equal.
+        return Status::OK();
+      default:
+        LOG(FATAL) << "Unsupported primitive type: "
+                   << PrimitiveType_Name(expected.shape().element_type());
+    }
+
+    if (!result.ok() && miscompare_callback) {
+      miscompare_callback(expected, actual, LiteralSlice(miscompared),
+                          shape_index);
     }
-    case TOKEN:
-      // Tokens have no on-device representation and are trivially equal.
-      return Status::OK();
-    default:
-      LOG(FATAL) << "Unsupported primitive type: "
-                 << PrimitiveType_Name(expected.shape().element_type());
   }
 
   return result;
@@ -726,9 +770,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
 // via recursion. shape_index is the ShapeIndex of expected (or actual)
 // currently being compared.
 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
-                  const ErrorSpec& error, absl::optional<bool> detailed_message,
-                  const MiscompareCallback& miscompare_callback,
-                  const ShapeIndex& shape_index) {
+                  const ShapeIndex& shape_index, const ErrorSpec& error,
+                  absl::optional<bool> detailed_message,
+                  const MiscompareCallback& miscompare_callback) {
   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
 
   if (expected.shape().IsTuple()) {
@@ -739,8 +783,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
       ShapeIndex element_index = shape_index;
       element_index.push_back(i);
       Status element_result =
-          NearHelper(expected_element, actual_element, error, detailed_message,
-                     miscompare_callback, element_index);
+          NearHelper(expected_element, actual_element, element_index, error,
+                     detailed_message, miscompare_callback);
       if (!element_result.ok()) {
         element_result = InvalidArgument("Array at shape index %s, %s",
                                          element_index.ToString(),
@@ -771,28 +815,34 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
         ShapeUtil::ElementsIn(expected.shape()) >= 64);
     switch (expected.shape().element_type()) {
       case BF16:
-        return NearComparator<bfloat16>::Compare(
-            expected, actual, error, use_detailed_message, miscompare_callback);
+        return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
+                                                 error, use_detailed_message,
+                                                 miscompare_callback);
         break;
       case F16:
-        return NearComparator<half>::Compare(
-            expected, actual, error, use_detailed_message, miscompare_callback);
+        return NearComparator<half>::Compare(expected, actual, shape_index,
+                                             error, use_detailed_message,
+                                             miscompare_callback);
         break;
       case F32:
-        return NearComparator<float>::Compare(
-            expected, actual, error, use_detailed_message, miscompare_callback);
+        return NearComparator<float>::Compare(expected, actual, shape_index,
+                                              error, use_detailed_message,
+                                              miscompare_callback);
         break;
       case F64:
-        return NearComparator<double>::Compare(
-            expected, actual, error, use_detailed_message, miscompare_callback);
+        return NearComparator<double>::Compare(expected, actual, shape_index,
+                                               error, use_detailed_message,
+                                               miscompare_callback);
         break;
       case C64:
-        return NearComparator<complex64>::Compare(
-            expected, actual, error, use_detailed_message, miscompare_callback);
+        return NearComparator<complex64>::Compare(expected, actual, shape_index,
+                                                  error, use_detailed_message,
+                                                  miscompare_callback);
         break;
       case C128:
         return NearComparator<complex128>::Compare(
-            expected, actual, error, use_detailed_message, miscompare_callback);
+            expected, actual, shape_index, error, use_detailed_message,
+            miscompare_callback);
         break;
       default:
         LOG(FATAL) << "Unsupported primitive type in near comparator: "
@@ -802,7 +852,7 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
   }
 
   // Non-floating point, non-tuple literal.
-  return EqualHelper(expected, actual);
+  return EqualHelper(expected, actual, shape_index, miscompare_callback);
 }
 
 }  // namespace
@@ -878,7 +928,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
   XLA_VLOG_LINES(1, expected.ToString());
   VLOG(1) << "actual:";
   XLA_VLOG_LINES(1, actual.ToString());
-  Status result = EqualHelper(expected, actual);
+  Status result = EqualHelper(expected, actual, {}, nullptr);
   return EmitLiteralsInErrorMessage(result, expected, actual);
 }
 
@@ -889,9 +939,8 @@ Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
   XLA_VLOG_LINES(1, expected.ToString());
   VLOG(1) << "Actual literal:";
   XLA_VLOG_LINES(1, actual.ToString());
-  Status result =
-      NearHelper(expected, actual, error, detailed_message, miscompare_callback,
-                 /*shape_index=*/{});
+  Status result = NearHelper(expected, actual, /*shape_index=*/{}, error,
+                             detailed_message, miscompare_callback);
   return EmitLiteralsInErrorMessage(result, expected, actual);
 }
 
diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h
index 23fff3fa348..a8ed74e3704 100644
--- a/tensorflow/compiler/xla/literal_comparison.h
+++ b/tensorflow/compiler/xla/literal_comparison.h
@@ -35,9 +35,9 @@ Status EqualShapes(const Shape& expected, const Shape& actual);
 // primitive type are equal.
 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual);
 
-using MiscompareCallback =
-    std::function<void(const LiteralSlice& expected, const LiteralSlice& actual,
-                       const LiteralSlice& mismatches)>;
+using MiscompareCallback = std::function<void(
+    const LiteralSlice& expected, const LiteralSlice& actual,
+    const LiteralSlice& mismatches, const ShapeIndex& shape_index)>;
 
 // Inspects whether the expected and actual literals are within the given error
 // bound for all elements. Also, inspects whether the rank, dimensions sizes,
@@ -57,6 +57,9 @@ using MiscompareCallback =
 // If detailed_message is true, then the error message in the assertion result
 // will contain a more detailed breakdown of mismatches.  By default, we display
 // a detailed message only for "large" inputs.
+//
+// If miscompare_callback is nullptr, Near will return an error on the first
+// detected mismatch.
 Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
             const ErrorSpec& error, absl::optional<bool> detailed_message,
             const MiscompareCallback& miscompare_callback);
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 7f725a97f28..4dd59cdca5d 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -51,7 +51,8 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
 // Callback helper that dumps literals to temporary files in the event of a
 // miscomparison.
 void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
-                  const LiteralSlice& mismatches) {
+                  const LiteralSlice& mismatches,
+                  const ShapeIndex& /*shape_index*/) {
   LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " "
             << literal_comparison::ToStringTruncated(expected);
   LOG(INFO) << "actual:   " << ShapeUtil::HumanString(actual.shape()) << " "