diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index dc11f7caa2c..c1376c6a3d9 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -160,15 +160,24 @@ Status MakeErrorStatus(complex128 lhs, complex128 rhs, // A recursive function which iterates through every index of expected and // actual literal and compares their values elementwise. Returns true if all -// elements are equal. +// elements are equal. Mismatched must either be: +// - a literal of booleans that has the same shape as expected and actual. In +// this case, each index in mismatched will be set to true if expected does +// not equal actual at that index and false if there are equal. +// - nullptr. In this case, the function will return once any mismatch is +// found between expected and actual. template <typename NativeT> Status Equal(LiteralSlice expected, LiteralSlice actual, - absl::Span<int64> multi_index, int64 dimension) { + absl::Span<int64> multi_index, int64 dimension, + Literal* mismatched = nullptr) { if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get<NativeT>(multi_index); NativeT actual_value = actual.Get<NativeT>(multi_index); bool result = CompareEqual<NativeT>(expected_value, actual_value, multi_index); + if (mismatched) { + mismatched->Set<bool>(multi_index, !result); + } return result ? Status::OK() : MakeErrorStatus<NativeT>(expected_value, actual_value, multi_index); @@ -177,8 +186,13 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, Status result; for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index[dimension] = i; - TF_RETURN_IF_ERROR( - Equal<NativeT>(expected, actual, multi_index, dimension + 1)); + if (mismatched != nullptr) { + result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1, + mismatched)); + } else { + TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index, + dimension + 1, mismatched)); + } } return result; } @@ -264,10 +278,11 @@ class NearComparator { // within the given error bound. In case of error, the status contains a // detailed message about the discrepancy. static Status Compare(const LiteralSlice& expected, - const LiteralSlice& actual, ErrorSpec error, + const LiteralSlice& actual, + const ShapeIndex& shape_index, ErrorSpec error, bool detailed_message, const MiscompareCallback& miscompare_callback) { - NearComparator<NativeT> comparator(expected, actual, error, + NearComparator<NativeT> comparator(expected, actual, shape_index, error, detailed_message, miscompare_callback); return comparator.Run(); } @@ -300,10 +315,12 @@ class NearComparator { }; NearComparator(const LiteralSlice& expected, const LiteralSlice& actual, - ErrorSpec error, bool detailed_message, + const ShapeIndex& shape_index, ErrorSpec error, + bool detailed_message, const MiscompareCallback& miscompare_callback) : expected_(expected), actual_(actual), + shape_index_(shape_index), error_(error), detailed_message_(detailed_message), miscompare_callback_(miscompare_callback), @@ -329,7 +346,7 @@ class NearComparator { if (num_mismatches_ == 0) { return Status::OK(); } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { - miscompare_callback_(expected_, actual_, mismatches_); + miscompare_callback_(expected_, actual_, mismatches_, shape_index_); } return InvalidArgument("%s", ErrorMessage()); } @@ -595,6 +612,9 @@ class NearComparator { LiteralSlice expected_; LiteralSlice actual_; + // The shape index of the LiteralSlice that is being compared. + ShapeIndex shape_index_; + // The error bounds of the comparison. ErrorSpec error_; @@ -653,70 +673,94 @@ constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds; template <typename NativeT> constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds; -Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { +Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual, + const ShapeIndex& shape_index, + const MiscompareCallback& miscompare_callback) { TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); - std::vector<int64> multi_index(expected.shape().dimensions_size(), 0); - auto index = absl::MakeSpan(multi_index); + Status result; - switch (expected.shape().element_type()) { - case PRED: - result = Equal<bool>(expected, actual, index, 0); - break; - case S8: - result = Equal<int8>(expected, actual, index, 0); - break; - case S16: - result = Equal<int16>(expected, actual, index, 0); - break; - case S32: - result = Equal<int32>(expected, actual, index, 0); - break; - case S64: - result = Equal<int64>(expected, actual, index, 0); - break; - case U8: - result = Equal<uint8>(expected, actual, index, 0); - break; - case U16: - result = Equal<uint16>(expected, actual, index, 0); - break; - case U32: - result = Equal<uint32>(expected, actual, index, 0); - break; - case U64: - result = Equal<uint64>(expected, actual, index, 0); - break; - case BF16: - result = Equal<bfloat16>(expected, actual, index, 0); - break; - case F16: - result = Equal<half>(expected, actual, index, 0); - break; - case F32: - result = Equal<float>(expected, actual, index, 0); - break; - case F64: - result = Equal<double>(expected, actual, index, 0); - break; - case C64: - result = Equal<complex64>(expected, actual, index, 0); - break; - case C128: - result = Equal<complex128>(expected, actual, index, 0); - break; - case TUPLE: { - for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - result.Update(EqualHelper(LiteralSlice(expected, {i}), - LiteralSlice(actual, {i}))); + if (expected.shape().IsTuple()) { + ShapeIndex next_index = shape_index; + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + next_index.push_back(i); + Status tuple_result = + EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}), + next_index, miscompare_callback); + if (miscompare_callback) { + result.Update(tuple_result); + } else { + TF_RETURN_IF_ERROR(tuple_result); } - break; + next_index.pop_back(); + } + } else { + std::vector<int64> multi_index(expected.shape().dimensions_size(), 0); + auto index = absl::MakeSpan(multi_index); + + Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED, + expected.shape().dimensions()); + Literal miscompared(unequal_shape); + Literal* miscompared_ptr = + (miscompare_callback == nullptr ? nullptr : &miscompared); + + switch (expected.shape().element_type()) { + case PRED: + result = Equal<bool>(expected, actual, index, 0, miscompared_ptr); + break; + case S8: + result = Equal<int8>(expected, actual, index, 0, miscompared_ptr); + break; + case S16: + result = Equal<int16>(expected, actual, index, 0, miscompared_ptr); + break; + case S32: + result = Equal<int32>(expected, actual, index, 0, miscompared_ptr); + break; + case S64: + result = Equal<int64>(expected, actual, index, 0, miscompared_ptr); + break; + case U8: + result = Equal<uint8>(expected, actual, index, 0, miscompared_ptr); + break; + case U16: + result = Equal<uint16>(expected, actual, index, 0, miscompared_ptr); + break; + case U32: + result = Equal<uint32>(expected, actual, index, 0, miscompared_ptr); + break; + case U64: + result = Equal<uint64>(expected, actual, index, 0, miscompared_ptr); + break; + case BF16: + result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr); + break; + case F16: + result = Equal<half>(expected, actual, index, 0, miscompared_ptr); + break; + case F32: + result = Equal<float>(expected, actual, index, 0, miscompared_ptr); + break; + case F64: + result = Equal<double>(expected, actual, index, 0, miscompared_ptr); + break; + case C64: + result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr); + break; + case C128: + result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr); + break; + case TOKEN: + // Tokens have no on-device representation and are trivially equal. + return Status::OK(); + default: + LOG(FATAL) << "Unsupported primitive type: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + if (!result.ok() && miscompare_callback) { + miscompare_callback(expected, actual, LiteralSlice(miscompared), + shape_index); } - case TOKEN: - // Tokens have no on-device representation and are trivially equal. - return Status::OK(); - default: - LOG(FATAL) << "Unsupported primitive type: " - << PrimitiveType_Name(expected.shape().element_type()); } return result; @@ -726,9 +770,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, absl::optional<bool> detailed_message, - const MiscompareCallback& miscompare_callback, - const ShapeIndex& shape_index) { + const ShapeIndex& shape_index, const ErrorSpec& error, + absl::optional<bool> detailed_message, + const MiscompareCallback& miscompare_callback) { TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); if (expected.shape().IsTuple()) { @@ -739,8 +783,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, ShapeIndex element_index = shape_index; element_index.push_back(i); Status element_result = - NearHelper(expected_element, actual_element, error, detailed_message, - miscompare_callback, element_index); + NearHelper(expected_element, actual_element, element_index, error, + detailed_message, miscompare_callback); if (!element_result.ok()) { element_result = InvalidArgument("Array at shape index %s, %s", element_index.ToString(), @@ -771,28 +815,34 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, ShapeUtil::ElementsIn(expected.shape()) >= 64); switch (expected.shape().element_type()) { case BF16: - return NearComparator<bfloat16>::Compare( - expected, actual, error, use_detailed_message, miscompare_callback); + return NearComparator<bfloat16>::Compare(expected, actual, shape_index, + error, use_detailed_message, + miscompare_callback); break; case F16: - return NearComparator<half>::Compare( - expected, actual, error, use_detailed_message, miscompare_callback); + return NearComparator<half>::Compare(expected, actual, shape_index, + error, use_detailed_message, + miscompare_callback); break; case F32: - return NearComparator<float>::Compare( - expected, actual, error, use_detailed_message, miscompare_callback); + return NearComparator<float>::Compare(expected, actual, shape_index, + error, use_detailed_message, + miscompare_callback); break; case F64: - return NearComparator<double>::Compare( - expected, actual, error, use_detailed_message, miscompare_callback); + return NearComparator<double>::Compare(expected, actual, shape_index, + error, use_detailed_message, + miscompare_callback); break; case C64: - return NearComparator<complex64>::Compare( - expected, actual, error, use_detailed_message, miscompare_callback); + return NearComparator<complex64>::Compare(expected, actual, shape_index, + error, use_detailed_message, + miscompare_callback); break; case C128: return NearComparator<complex128>::Compare( - expected, actual, error, use_detailed_message, miscompare_callback); + expected, actual, shape_index, error, use_detailed_message, + miscompare_callback); break; default: LOG(FATAL) << "Unsupported primitive type in near comparator: " @@ -802,7 +852,7 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, } // Non-floating point, non-tuple literal. - return EqualHelper(expected, actual); + return EqualHelper(expected, actual, shape_index, miscompare_callback); } } // namespace @@ -878,7 +928,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { XLA_VLOG_LINES(1, expected.ToString()); VLOG(1) << "actual:"; XLA_VLOG_LINES(1, actual.ToString()); - Status result = EqualHelper(expected, actual); + Status result = EqualHelper(expected, actual, {}, nullptr); return EmitLiteralsInErrorMessage(result, expected, actual); } @@ -889,9 +939,8 @@ Status Near(const LiteralSlice& expected, const LiteralSlice& actual, XLA_VLOG_LINES(1, expected.ToString()); VLOG(1) << "Actual literal:"; XLA_VLOG_LINES(1, actual.ToString()); - Status result = - NearHelper(expected, actual, error, detailed_message, miscompare_callback, - /*shape_index=*/{}); + Status result = NearHelper(expected, actual, /*shape_index=*/{}, error, + detailed_message, miscompare_callback); return EmitLiteralsInErrorMessage(result, expected, actual); } diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 23fff3fa348..a8ed74e3704 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -35,9 +35,9 @@ Status EqualShapes(const Shape& expected, const Shape& actual); // primitive type are equal. Status Equal(const LiteralSlice& expected, const LiteralSlice& actual); -using MiscompareCallback = - std::function<void(const LiteralSlice& expected, const LiteralSlice& actual, - const LiteralSlice& mismatches)>; +using MiscompareCallback = std::function<void( + const LiteralSlice& expected, const LiteralSlice& actual, + const LiteralSlice& mismatches, const ShapeIndex& shape_index)>; // Inspects whether the expected and actual literals are within the given error // bound for all elements. Also, inspects whether the rank, dimensions sizes, @@ -57,6 +57,9 @@ using MiscompareCallback = // If detailed_message is true, then the error message in the assertion result // will contain a more detailed breakdown of mismatches. By default, we display // a detailed message only for "large" inputs. +// +// If miscompare_callback is nullptr, Near will return an error on the first +// detected mismatch. Status Near(const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error, absl::optional<bool> detailed_message, const MiscompareCallback& miscompare_callback); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 7f725a97f28..4dd59cdca5d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -51,7 +51,8 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { // Callback helper that dumps literals to temporary files in the event of a // miscomparison. void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, - const LiteralSlice& mismatches) { + const LiteralSlice& mismatches, + const ShapeIndex& /*shape_index*/) { LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " " << literal_comparison::ToStringTruncated(expected); LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " "