Near comparison now works for non-fp types. Callback uses ShapeIndex parameter.

literal_comparison::Near will now make a call to miscompare_callback when non
floating/complex types fail the check.

literal_comparison::Near may now accept nullptr as a miscompare_callback. In
this case, it will return on the first failed comparison.

MiscompareCallback functions now have the ShapeIndex of the LiteralSlice being
compared as a parameter.

PiperOrigin-RevId: 251722837
This commit is contained in:
A. Unique TensorFlower 2019-06-05 14:45:28 -07:00 committed by TensorFlower Gardener
parent b3d37510ad
commit abd1460193
3 changed files with 146 additions and 93 deletions

View File

@ -160,15 +160,24 @@ Status MakeErrorStatus(complex128 lhs, complex128 rhs,
// A recursive function which iterates through every index of expected and
// actual literal and compares their values elementwise. Returns true if all
// elements are equal.
// elements are equal. Mismatched must either be:
// - a literal of booleans that has the same shape as expected and actual. In
// this case, each index in mismatched will be set to true if expected does
// not equal actual at that index and false if there are equal.
// - nullptr. In this case, the function will return once any mismatch is
// found between expected and actual.
template <typename NativeT>
Status Equal(LiteralSlice expected, LiteralSlice actual,
absl::Span<int64> multi_index, int64 dimension) {
absl::Span<int64> multi_index, int64 dimension,
Literal* mismatched = nullptr) {
if (dimension == expected.shape().dimensions_size()) {
NativeT expected_value = expected.Get<NativeT>(multi_index);
NativeT actual_value = actual.Get<NativeT>(multi_index);
bool result =
CompareEqual<NativeT>(expected_value, actual_value, multi_index);
if (mismatched) {
mismatched->Set<bool>(multi_index, !result);
}
return result ? Status::OK()
: MakeErrorStatus<NativeT>(expected_value, actual_value,
multi_index);
@ -177,8 +186,13 @@ Status Equal(LiteralSlice expected, LiteralSlice actual,
Status result;
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
multi_index[dimension] = i;
TF_RETURN_IF_ERROR(
Equal<NativeT>(expected, actual, multi_index, dimension + 1));
if (mismatched != nullptr) {
result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1,
mismatched));
} else {
TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index,
dimension + 1, mismatched));
}
}
return result;
}
@ -264,10 +278,11 @@ class NearComparator {
// within the given error bound. In case of error, the status contains a
// detailed message about the discrepancy.
static Status Compare(const LiteralSlice& expected,
const LiteralSlice& actual, ErrorSpec error,
const LiteralSlice& actual,
const ShapeIndex& shape_index, ErrorSpec error,
bool detailed_message,
const MiscompareCallback& miscompare_callback) {
NearComparator<NativeT> comparator(expected, actual, error,
NearComparator<NativeT> comparator(expected, actual, shape_index, error,
detailed_message, miscompare_callback);
return comparator.Run();
}
@ -300,10 +315,12 @@ class NearComparator {
};
NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
ErrorSpec error, bool detailed_message,
const ShapeIndex& shape_index, ErrorSpec error,
bool detailed_message,
const MiscompareCallback& miscompare_callback)
: expected_(expected),
actual_(actual),
shape_index_(shape_index),
error_(error),
detailed_message_(detailed_message),
miscompare_callback_(miscompare_callback),
@ -329,7 +346,7 @@ class NearComparator {
if (num_mismatches_ == 0) {
return Status::OK();
} else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
miscompare_callback_(expected_, actual_, mismatches_);
miscompare_callback_(expected_, actual_, mismatches_, shape_index_);
}
return InvalidArgument("%s", ErrorMessage());
}
@ -595,6 +612,9 @@ class NearComparator {
LiteralSlice expected_;
LiteralSlice actual_;
// The shape index of the LiteralSlice that is being compared.
ShapeIndex shape_index_;
// The error bounds of the comparison.
ErrorSpec error_;
@ -653,70 +673,94 @@ constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
template <typename NativeT>
constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual,
const ShapeIndex& shape_index,
const MiscompareCallback& miscompare_callback) {
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
auto index = absl::MakeSpan(multi_index);
Status result;
switch (expected.shape().element_type()) {
case PRED:
result = Equal<bool>(expected, actual, index, 0);
break;
case S8:
result = Equal<int8>(expected, actual, index, 0);
break;
case S16:
result = Equal<int16>(expected, actual, index, 0);
break;
case S32:
result = Equal<int32>(expected, actual, index, 0);
break;
case S64:
result = Equal<int64>(expected, actual, index, 0);
break;
case U8:
result = Equal<uint8>(expected, actual, index, 0);
break;
case U16:
result = Equal<uint16>(expected, actual, index, 0);
break;
case U32:
result = Equal<uint32>(expected, actual, index, 0);
break;
case U64:
result = Equal<uint64>(expected, actual, index, 0);
break;
case BF16:
result = Equal<bfloat16>(expected, actual, index, 0);
break;
case F16:
result = Equal<half>(expected, actual, index, 0);
break;
case F32:
result = Equal<float>(expected, actual, index, 0);
break;
case F64:
result = Equal<double>(expected, actual, index, 0);
break;
case C64:
result = Equal<complex64>(expected, actual, index, 0);
break;
case C128:
result = Equal<complex128>(expected, actual, index, 0);
break;
case TUPLE: {
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
result.Update(EqualHelper(LiteralSlice(expected, {i}),
LiteralSlice(actual, {i})));
if (expected.shape().IsTuple()) {
ShapeIndex next_index = shape_index;
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
next_index.push_back(i);
Status tuple_result =
EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}),
next_index, miscompare_callback);
if (miscompare_callback) {
result.Update(tuple_result);
} else {
TF_RETURN_IF_ERROR(tuple_result);
}
break;
next_index.pop_back();
}
} else {
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
auto index = absl::MakeSpan(multi_index);
Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED,
expected.shape().dimensions());
Literal miscompared(unequal_shape);
Literal* miscompared_ptr =
(miscompare_callback == nullptr ? nullptr : &miscompared);
switch (expected.shape().element_type()) {
case PRED:
result = Equal<bool>(expected, actual, index, 0, miscompared_ptr);
break;
case S8:
result = Equal<int8>(expected, actual, index, 0, miscompared_ptr);
break;
case S16:
result = Equal<int16>(expected, actual, index, 0, miscompared_ptr);
break;
case S32:
result = Equal<int32>(expected, actual, index, 0, miscompared_ptr);
break;
case S64:
result = Equal<int64>(expected, actual, index, 0, miscompared_ptr);
break;
case U8:
result = Equal<uint8>(expected, actual, index, 0, miscompared_ptr);
break;
case U16:
result = Equal<uint16>(expected, actual, index, 0, miscompared_ptr);
break;
case U32:
result = Equal<uint32>(expected, actual, index, 0, miscompared_ptr);
break;
case U64:
result = Equal<uint64>(expected, actual, index, 0, miscompared_ptr);
break;
case BF16:
result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr);
break;
case F16:
result = Equal<half>(expected, actual, index, 0, miscompared_ptr);
break;
case F32:
result = Equal<float>(expected, actual, index, 0, miscompared_ptr);
break;
case F64:
result = Equal<double>(expected, actual, index, 0, miscompared_ptr);
break;
case C64:
result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr);
break;
case C128:
result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr);
break;
case TOKEN:
// Tokens have no on-device representation and are trivially equal.
return Status::OK();
default:
LOG(FATAL) << "Unsupported primitive type: "
<< PrimitiveType_Name(expected.shape().element_type());
}
if (!result.ok() && miscompare_callback) {
miscompare_callback(expected, actual, LiteralSlice(miscompared),
shape_index);
}
case TOKEN:
// Tokens have no on-device representation and are trivially equal.
return Status::OK();
default:
LOG(FATAL) << "Unsupported primitive type: "
<< PrimitiveType_Name(expected.shape().element_type());
}
return result;
@ -726,9 +770,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
// via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared.
Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, absl::optional<bool> detailed_message,
const MiscompareCallback& miscompare_callback,
const ShapeIndex& shape_index) {
const ShapeIndex& shape_index, const ErrorSpec& error,
absl::optional<bool> detailed_message,
const MiscompareCallback& miscompare_callback) {
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
if (expected.shape().IsTuple()) {
@ -739,8 +783,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
ShapeIndex element_index = shape_index;
element_index.push_back(i);
Status element_result =
NearHelper(expected_element, actual_element, error, detailed_message,
miscompare_callback, element_index);
NearHelper(expected_element, actual_element, element_index, error,
detailed_message, miscompare_callback);
if (!element_result.ok()) {
element_result = InvalidArgument("Array at shape index %s, %s",
element_index.ToString(),
@ -771,28 +815,34 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
ShapeUtil::ElementsIn(expected.shape()) >= 64);
switch (expected.shape().element_type()) {
case BF16:
return NearComparator<bfloat16>::Compare(
expected, actual, error, use_detailed_message, miscompare_callback);
return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
error, use_detailed_message,
miscompare_callback);
break;
case F16:
return NearComparator<half>::Compare(
expected, actual, error, use_detailed_message, miscompare_callback);
return NearComparator<half>::Compare(expected, actual, shape_index,
error, use_detailed_message,
miscompare_callback);
break;
case F32:
return NearComparator<float>::Compare(
expected, actual, error, use_detailed_message, miscompare_callback);
return NearComparator<float>::Compare(expected, actual, shape_index,
error, use_detailed_message,
miscompare_callback);
break;
case F64:
return NearComparator<double>::Compare(
expected, actual, error, use_detailed_message, miscompare_callback);
return NearComparator<double>::Compare(expected, actual, shape_index,
error, use_detailed_message,
miscompare_callback);
break;
case C64:
return NearComparator<complex64>::Compare(
expected, actual, error, use_detailed_message, miscompare_callback);
return NearComparator<complex64>::Compare(expected, actual, shape_index,
error, use_detailed_message,
miscompare_callback);
break;
case C128:
return NearComparator<complex128>::Compare(
expected, actual, error, use_detailed_message, miscompare_callback);
expected, actual, shape_index, error, use_detailed_message,
miscompare_callback);
break;
default:
LOG(FATAL) << "Unsupported primitive type in near comparator: "
@ -802,7 +852,7 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
}
// Non-floating point, non-tuple literal.
return EqualHelper(expected, actual);
return EqualHelper(expected, actual, shape_index, miscompare_callback);
}
} // namespace
@ -878,7 +928,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "actual:";
XLA_VLOG_LINES(1, actual.ToString());
Status result = EqualHelper(expected, actual);
Status result = EqualHelper(expected, actual, {}, nullptr);
return EmitLiteralsInErrorMessage(result, expected, actual);
}
@ -889,9 +939,8 @@ Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "Actual literal:";
XLA_VLOG_LINES(1, actual.ToString());
Status result =
NearHelper(expected, actual, error, detailed_message, miscompare_callback,
/*shape_index=*/{});
Status result = NearHelper(expected, actual, /*shape_index=*/{}, error,
detailed_message, miscompare_callback);
return EmitLiteralsInErrorMessage(result, expected, actual);
}

View File

@ -35,9 +35,9 @@ Status EqualShapes(const Shape& expected, const Shape& actual);
// primitive type are equal.
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual);
using MiscompareCallback =
std::function<void(const LiteralSlice& expected, const LiteralSlice& actual,
const LiteralSlice& mismatches)>;
using MiscompareCallback = std::function<void(
const LiteralSlice& expected, const LiteralSlice& actual,
const LiteralSlice& mismatches, const ShapeIndex& shape_index)>;
// Inspects whether the expected and actual literals are within the given error
// bound for all elements. Also, inspects whether the rank, dimensions sizes,
@ -57,6 +57,9 @@ using MiscompareCallback =
// If detailed_message is true, then the error message in the assertion result
// will contain a more detailed breakdown of mismatches. By default, we display
// a detailed message only for "large" inputs.
//
// If miscompare_callback is nullptr, Near will return an error on the first
// detected mismatch.
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, absl::optional<bool> detailed_message,
const MiscompareCallback& miscompare_callback);

View File

@ -51,7 +51,8 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
// Callback helper that dumps literals to temporary files in the event of a
// miscomparison.
void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
const LiteralSlice& mismatches) {
const LiteralSlice& mismatches,
const ShapeIndex& /*shape_index*/) {
LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " "
<< literal_comparison::ToStringTruncated(expected);
LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " "