[XLA] By default, display a detailed breakdown when comparing large literals.

Literal comparison has the option of displaying a "detailed breakdown", which
includes things like the distribution of errors.  Previously, you'd only get
this breakdown if you asked for it, but there are many routines that don't
expose this knob.  We could expose it piecemeal, but I think in general if
you're comparing a "large" literal and you have mismatches, you probably want
this breakdown.  At least, it's not going to be a lot of noise.  So this patch
makes us show the default breakdown by default when the array being compared is
above a certain size.

PiperOrigin-RevId: 230834776
This commit is contained in:
Justin Lebar 2019-01-24 19:25:06 -08:00 committed by TensorFlower Gardener
parent c00907e080
commit a827a4bee1
4 changed files with 17 additions and 13 deletions

View File

@ -736,7 +736,7 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
// via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared.
Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, bool detailed_message,
const ErrorSpec& error, absl::optional<bool> detailed_message,
const MiscompareCallback& miscompare_callback,
const ShapeIndex& shape_index) {
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
@ -777,30 +777,32 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
if (ShapeUtil::ElementIsFloating(expected.shape()) ||
ShapeUtil::ElementIsComplex(expected.shape())) {
bool use_detailed_message = detailed_message.value_or(
ShapeUtil::ElementsIn(expected.shape()) >= 64);
switch (expected.shape().element_type()) {
case BF16:
return NearComparator<bfloat16>::Compare(
expected, actual, error, detailed_message, miscompare_callback);
expected, actual, error, use_detailed_message, miscompare_callback);
break;
case F16:
return NearComparator<half>::Compare(
expected, actual, error, detailed_message, miscompare_callback);
expected, actual, error, use_detailed_message, miscompare_callback);
break;
case F32:
return NearComparator<float>::Compare(
expected, actual, error, detailed_message, miscompare_callback);
expected, actual, error, use_detailed_message, miscompare_callback);
break;
case F64:
return NearComparator<double>::Compare(
expected, actual, error, detailed_message, miscompare_callback);
expected, actual, error, use_detailed_message, miscompare_callback);
break;
case C64:
return NearComparator<complex64>::Compare(
expected, actual, error, detailed_message, miscompare_callback);
expected, actual, error, use_detailed_message, miscompare_callback);
break;
case C128:
return NearComparator<complex128>::Compare(
expected, actual, error, detailed_message, miscompare_callback);
expected, actual, error, use_detailed_message, miscompare_callback);
break;
default:
LOG(FATAL) << "Unsupported primitive type in near comparator: "
@ -891,7 +893,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
}
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, bool detailed_message,
const ErrorSpec& error, absl::optional<bool> detailed_message,
const MiscompareCallback& miscompare_callback) {
VLOG(1) << "Expected literal:";
XLA_VLOG_LINES(1, expected.ToString());

View File

@ -55,9 +55,10 @@ using MiscompareCallback =
// being compared.
//
// If detailed_message is true, then the error message in the assertion result
// will contain a more detailed breakdown of mismatches.
// will contain a more detailed breakdown of mismatches. By default, we display
// a detailed message only for "large" inputs.
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, bool detailed_message,
const ErrorSpec& error, absl::optional<bool> detailed_message,
const MiscompareCallback& miscompare_callback);
// Calling ToString on a literal with over 100 million elements takes around

View File

@ -86,7 +86,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error_spec, bool detailed_message) {
const ErrorSpec& error_spec, absl::optional<bool> detailed_message) {
return StatusToAssertion(literal_comparison::Near(
expected, actual, error_spec, detailed_message, &OnMiscompare));
}
@ -97,7 +97,8 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
if (error.has_value()) {
VLOG(1) << "Expects near";
return StatusToAssertion(literal_comparison::Near(
expected, actual, *error, /*detailed_message=*/false, &OnMiscompare));
expected, actual, *error, /*detailed_message=*/absl::nullopt,
&OnMiscompare));
}
VLOG(1) << "Expects equal";
return StatusToAssertion(literal_comparison::Equal(expected, actual));

View File

@ -93,7 +93,7 @@ class LiteralTestUtil {
static ::testing::AssertionResult Near(
const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error_spec,
bool detailed_message = false) TF_MUST_USE_RESULT;
absl::optional<bool> detailed_message = absl::nullopt) TF_MUST_USE_RESULT;
// Asserts the given literal are within the given error bound of the given
// expected values. Only supported for floating point values.