diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 07ae95063c0..9b3de75dd4e 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -736,7 +736,7 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback, const ShapeIndex& shape_index) { TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); @@ -777,30 +777,32 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { + bool use_detailed_message = detailed_message.value_or( + ShapeUtil::ElementsIn(expected.shape()) >= 64); switch (expected.shape().element_type()) { case BF16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F32: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case C64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case C128: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; default: LOG(FATAL) << "Unsupported primitive type in near comparator: " @@ -891,7 +893,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback) { VLOG(1) << "Expected literal:"; XLA_VLOG_LINES(1, expected.ToString()); diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 9e5bf7c1d06..23fff3fa348 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -55,9 +55,10 @@ using MiscompareCallback = // being compared. // // If detailed_message is true, then the error message in the assertion result -// will contain a more detailed breakdown of mismatches. +// will contain a more detailed breakdown of mismatches. By default, we display +// a detailed message only for "large" inputs. Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback); // Calling ToString on a literal with over 100 million elements takes around diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 554eb24d441..a2fd6070731 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -86,7 +86,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::Near( const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error_spec, bool detailed_message) { + const ErrorSpec& error_spec, absl::optional detailed_message) { return StatusToAssertion(literal_comparison::Near( expected, actual, error_spec, detailed_message, &OnMiscompare)); } @@ -97,7 +97,8 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, if (error.has_value()) { VLOG(1) << "Expects near"; return StatusToAssertion(literal_comparison::Near( - expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); + expected, actual, *error, /*detailed_message=*/absl::nullopt, + &OnMiscompare)); } VLOG(1) << "Expects equal"; return StatusToAssertion(literal_comparison::Equal(expected, actual)); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 43cca91f64b..d7cf9bed98a 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -93,7 +93,7 @@ class LiteralTestUtil { static ::testing::AssertionResult Near( const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error_spec, - bool detailed_message = false) TF_MUST_USE_RESULT; + absl::optional detailed_message = absl::nullopt) TF_MUST_USE_RESULT; // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values.