[XLA] By default, display a detailed breakdown when comparing large literals.

Literal comparison has the option of displaying a "detailed breakdown", which
includes things like the distribution of errors.  Previously, you'd only get
this breakdown if you asked for it, but there are many routines that don't
expose this knob.  We could expose it piecemeal, but I think in general if
you're comparing a "large" literal and you have mismatches, you probably want
this breakdown.  At least, it's not going to be a lot of noise.  So this patch
makes us show the default breakdown by default when the array being compared is
above a certain size.

PiperOrigin-RevId: 230834776
This commit is contained in:
Justin Lebar 2019-01-24 19:25:06 -08:00 committed by TensorFlower Gardener
parent c00907e080
commit a827a4bee1
4 changed files with 17 additions and 13 deletions

View File

@ -736,7 +736,7 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
// via recursion. shape_index is the ShapeIndex of expected (or actual) // via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared. // currently being compared.
Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, bool detailed_message, const ErrorSpec& error, absl::optional<bool> detailed_message,
const MiscompareCallback& miscompare_callback, const MiscompareCallback& miscompare_callback,
const ShapeIndex& shape_index) { const ShapeIndex& shape_index) {
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
@ -777,30 +777,32 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
if (ShapeUtil::ElementIsFloating(expected.shape()) || if (ShapeUtil::ElementIsFloating(expected.shape()) ||
ShapeUtil::ElementIsComplex(expected.shape())) { ShapeUtil::ElementIsComplex(expected.shape())) {
bool use_detailed_message = detailed_message.value_or(
ShapeUtil::ElementsIn(expected.shape()) >= 64);
switch (expected.shape().element_type()) { switch (expected.shape().element_type()) {
case BF16: case BF16:
return NearComparator<bfloat16>::Compare( return NearComparator<bfloat16>::Compare(
expected, actual, error, detailed_message, miscompare_callback); expected, actual, error, use_detailed_message, miscompare_callback);
break; break;
case F16: case F16:
return NearComparator<half>::Compare( return NearComparator<half>::Compare(
expected, actual, error, detailed_message, miscompare_callback); expected, actual, error, use_detailed_message, miscompare_callback);
break; break;
case F32: case F32:
return NearComparator<float>::Compare( return NearComparator<float>::Compare(
expected, actual, error, detailed_message, miscompare_callback); expected, actual, error, use_detailed_message, miscompare_callback);
break; break;
case F64: case F64:
return NearComparator<double>::Compare( return NearComparator<double>::Compare(
expected, actual, error, detailed_message, miscompare_callback); expected, actual, error, use_detailed_message, miscompare_callback);
break; break;
case C64: case C64:
return NearComparator<complex64>::Compare( return NearComparator<complex64>::Compare(
expected, actual, error, detailed_message, miscompare_callback); expected, actual, error, use_detailed_message, miscompare_callback);
break; break;
case C128: case C128:
return NearComparator<complex128>::Compare( return NearComparator<complex128>::Compare(
expected, actual, error, detailed_message, miscompare_callback); expected, actual, error, use_detailed_message, miscompare_callback);
break; break;
default: default:
LOG(FATAL) << "Unsupported primitive type in near comparator: " LOG(FATAL) << "Unsupported primitive type in near comparator: "
@ -891,7 +893,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
} }
Status Near(const LiteralSlice& expected, const LiteralSlice& actual, Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, bool detailed_message, const ErrorSpec& error, absl::optional<bool> detailed_message,
const MiscompareCallback& miscompare_callback) { const MiscompareCallback& miscompare_callback) {
VLOG(1) << "Expected literal:"; VLOG(1) << "Expected literal:";
XLA_VLOG_LINES(1, expected.ToString()); XLA_VLOG_LINES(1, expected.ToString());

View File

@ -55,9 +55,10 @@ using MiscompareCallback =
// being compared. // being compared.
// //
// If detailed_message is true, then the error message in the assertion result // If detailed_message is true, then the error message in the assertion result
// will contain a more detailed breakdown of mismatches. // will contain a more detailed breakdown of mismatches. By default, we display
// a detailed message only for "large" inputs.
Status Near(const LiteralSlice& expected, const LiteralSlice& actual, Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, bool detailed_message, const ErrorSpec& error, absl::optional<bool> detailed_message,
const MiscompareCallback& miscompare_callback); const MiscompareCallback& miscompare_callback);
// Calling ToString on a literal with over 100 million elements takes around // Calling ToString on a literal with over 100 million elements takes around

View File

@ -86,7 +86,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
/* static */ ::testing::AssertionResult LiteralTestUtil::Near( /* static */ ::testing::AssertionResult LiteralTestUtil::Near(
const LiteralSlice& expected, const LiteralSlice& actual, const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error_spec, bool detailed_message) { const ErrorSpec& error_spec, absl::optional<bool> detailed_message) {
return StatusToAssertion(literal_comparison::Near( return StatusToAssertion(literal_comparison::Near(
expected, actual, error_spec, detailed_message, &OnMiscompare)); expected, actual, error_spec, detailed_message, &OnMiscompare));
} }
@ -97,7 +97,8 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
if (error.has_value()) { if (error.has_value()) {
VLOG(1) << "Expects near"; VLOG(1) << "Expects near";
return StatusToAssertion(literal_comparison::Near( return StatusToAssertion(literal_comparison::Near(
expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); expected, actual, *error, /*detailed_message=*/absl::nullopt,
&OnMiscompare));
} }
VLOG(1) << "Expects equal"; VLOG(1) << "Expects equal";
return StatusToAssertion(literal_comparison::Equal(expected, actual)); return StatusToAssertion(literal_comparison::Equal(expected, actual));

View File

@ -93,7 +93,7 @@ class LiteralTestUtil {
static ::testing::AssertionResult Near( static ::testing::AssertionResult Near(
const LiteralSlice& expected, const LiteralSlice& actual, const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error_spec, const ErrorSpec& error_spec,
bool detailed_message = false) TF_MUST_USE_RESULT; absl::optional<bool> detailed_message = absl::nullopt) TF_MUST_USE_RESULT;
// Asserts the given literal are within the given error bound of the given // Asserts the given literal are within the given error bound of the given
// expected values. Only supported for floating point values. // expected values. Only supported for floating point values.