Near comparison now works for non-fp types. Callback uses ShapeIndex parameter.
literal_comparison::Near will now make a call to miscompare_callback when non floating/complex types fail the check. literal_comparison::Near may now accept nullptr as a miscompare_callback. In this case, it will return on the first failed comparison. MiscompareCallback functions now have the ShapeIndex of the LiteralSlice being compared as a parameter. PiperOrigin-RevId: 251722837
This commit is contained in:
parent
b3d37510ad
commit
abd1460193
@ -160,15 +160,24 @@ Status MakeErrorStatus(complex128 lhs, complex128 rhs,
|
||||
|
||||
// A recursive function which iterates through every index of expected and
|
||||
// actual literal and compares their values elementwise. Returns true if all
|
||||
// elements are equal.
|
||||
// elements are equal. Mismatched must either be:
|
||||
// - a literal of booleans that has the same shape as expected and actual. In
|
||||
// this case, each index in mismatched will be set to true if expected does
|
||||
// not equal actual at that index and false if there are equal.
|
||||
// - nullptr. In this case, the function will return once any mismatch is
|
||||
// found between expected and actual.
|
||||
template <typename NativeT>
|
||||
Status Equal(LiteralSlice expected, LiteralSlice actual,
|
||||
absl::Span<int64> multi_index, int64 dimension) {
|
||||
absl::Span<int64> multi_index, int64 dimension,
|
||||
Literal* mismatched = nullptr) {
|
||||
if (dimension == expected.shape().dimensions_size()) {
|
||||
NativeT expected_value = expected.Get<NativeT>(multi_index);
|
||||
NativeT actual_value = actual.Get<NativeT>(multi_index);
|
||||
bool result =
|
||||
CompareEqual<NativeT>(expected_value, actual_value, multi_index);
|
||||
if (mismatched) {
|
||||
mismatched->Set<bool>(multi_index, !result);
|
||||
}
|
||||
return result ? Status::OK()
|
||||
: MakeErrorStatus<NativeT>(expected_value, actual_value,
|
||||
multi_index);
|
||||
@ -177,8 +186,13 @@ Status Equal(LiteralSlice expected, LiteralSlice actual,
|
||||
Status result;
|
||||
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
|
||||
multi_index[dimension] = i;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Equal<NativeT>(expected, actual, multi_index, dimension + 1));
|
||||
if (mismatched != nullptr) {
|
||||
result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1,
|
||||
mismatched));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index,
|
||||
dimension + 1, mismatched));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -264,10 +278,11 @@ class NearComparator {
|
||||
// within the given error bound. In case of error, the status contains a
|
||||
// detailed message about the discrepancy.
|
||||
static Status Compare(const LiteralSlice& expected,
|
||||
const LiteralSlice& actual, ErrorSpec error,
|
||||
const LiteralSlice& actual,
|
||||
const ShapeIndex& shape_index, ErrorSpec error,
|
||||
bool detailed_message,
|
||||
const MiscompareCallback& miscompare_callback) {
|
||||
NearComparator<NativeT> comparator(expected, actual, error,
|
||||
NearComparator<NativeT> comparator(expected, actual, shape_index, error,
|
||||
detailed_message, miscompare_callback);
|
||||
return comparator.Run();
|
||||
}
|
||||
@ -300,10 +315,12 @@ class NearComparator {
|
||||
};
|
||||
|
||||
NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
ErrorSpec error, bool detailed_message,
|
||||
const ShapeIndex& shape_index, ErrorSpec error,
|
||||
bool detailed_message,
|
||||
const MiscompareCallback& miscompare_callback)
|
||||
: expected_(expected),
|
||||
actual_(actual),
|
||||
shape_index_(shape_index),
|
||||
error_(error),
|
||||
detailed_message_(detailed_message),
|
||||
miscompare_callback_(miscompare_callback),
|
||||
@ -329,7 +346,7 @@ class NearComparator {
|
||||
if (num_mismatches_ == 0) {
|
||||
return Status::OK();
|
||||
} else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
|
||||
miscompare_callback_(expected_, actual_, mismatches_);
|
||||
miscompare_callback_(expected_, actual_, mismatches_, shape_index_);
|
||||
}
|
||||
return InvalidArgument("%s", ErrorMessage());
|
||||
}
|
||||
@ -595,6 +612,9 @@ class NearComparator {
|
||||
LiteralSlice expected_;
|
||||
LiteralSlice actual_;
|
||||
|
||||
// The shape index of the LiteralSlice that is being compared.
|
||||
ShapeIndex shape_index_;
|
||||
|
||||
// The error bounds of the comparison.
|
||||
ErrorSpec error_;
|
||||
|
||||
@ -653,70 +673,94 @@ constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
|
||||
template <typename NativeT>
|
||||
constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
|
||||
|
||||
Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
|
||||
Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
const ShapeIndex& shape_index,
|
||||
const MiscompareCallback& miscompare_callback) {
|
||||
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
|
||||
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
|
||||
auto index = absl::MakeSpan(multi_index);
|
||||
|
||||
Status result;
|
||||
switch (expected.shape().element_type()) {
|
||||
case PRED:
|
||||
result = Equal<bool>(expected, actual, index, 0);
|
||||
break;
|
||||
case S8:
|
||||
result = Equal<int8>(expected, actual, index, 0);
|
||||
break;
|
||||
case S16:
|
||||
result = Equal<int16>(expected, actual, index, 0);
|
||||
break;
|
||||
case S32:
|
||||
result = Equal<int32>(expected, actual, index, 0);
|
||||
break;
|
||||
case S64:
|
||||
result = Equal<int64>(expected, actual, index, 0);
|
||||
break;
|
||||
case U8:
|
||||
result = Equal<uint8>(expected, actual, index, 0);
|
||||
break;
|
||||
case U16:
|
||||
result = Equal<uint16>(expected, actual, index, 0);
|
||||
break;
|
||||
case U32:
|
||||
result = Equal<uint32>(expected, actual, index, 0);
|
||||
break;
|
||||
case U64:
|
||||
result = Equal<uint64>(expected, actual, index, 0);
|
||||
break;
|
||||
case BF16:
|
||||
result = Equal<bfloat16>(expected, actual, index, 0);
|
||||
break;
|
||||
case F16:
|
||||
result = Equal<half>(expected, actual, index, 0);
|
||||
break;
|
||||
case F32:
|
||||
result = Equal<float>(expected, actual, index, 0);
|
||||
break;
|
||||
case F64:
|
||||
result = Equal<double>(expected, actual, index, 0);
|
||||
break;
|
||||
case C64:
|
||||
result = Equal<complex64>(expected, actual, index, 0);
|
||||
break;
|
||||
case C128:
|
||||
result = Equal<complex128>(expected, actual, index, 0);
|
||||
break;
|
||||
case TUPLE: {
|
||||
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
|
||||
result.Update(EqualHelper(LiteralSlice(expected, {i}),
|
||||
LiteralSlice(actual, {i})));
|
||||
if (expected.shape().IsTuple()) {
|
||||
ShapeIndex next_index = shape_index;
|
||||
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
|
||||
next_index.push_back(i);
|
||||
Status tuple_result =
|
||||
EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}),
|
||||
next_index, miscompare_callback);
|
||||
if (miscompare_callback) {
|
||||
result.Update(tuple_result);
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(tuple_result);
|
||||
}
|
||||
break;
|
||||
next_index.pop_back();
|
||||
}
|
||||
} else {
|
||||
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
|
||||
auto index = absl::MakeSpan(multi_index);
|
||||
|
||||
Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED,
|
||||
expected.shape().dimensions());
|
||||
Literal miscompared(unequal_shape);
|
||||
Literal* miscompared_ptr =
|
||||
(miscompare_callback == nullptr ? nullptr : &miscompared);
|
||||
|
||||
switch (expected.shape().element_type()) {
|
||||
case PRED:
|
||||
result = Equal<bool>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case S8:
|
||||
result = Equal<int8>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case S16:
|
||||
result = Equal<int16>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case S32:
|
||||
result = Equal<int32>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case S64:
|
||||
result = Equal<int64>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case U8:
|
||||
result = Equal<uint8>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case U16:
|
||||
result = Equal<uint16>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case U32:
|
||||
result = Equal<uint32>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case U64:
|
||||
result = Equal<uint64>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case BF16:
|
||||
result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case F16:
|
||||
result = Equal<half>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case F32:
|
||||
result = Equal<float>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case F64:
|
||||
result = Equal<double>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case C64:
|
||||
result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case C128:
|
||||
result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr);
|
||||
break;
|
||||
case TOKEN:
|
||||
// Tokens have no on-device representation and are trivially equal.
|
||||
return Status::OK();
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported primitive type: "
|
||||
<< PrimitiveType_Name(expected.shape().element_type());
|
||||
}
|
||||
|
||||
if (!result.ok() && miscompare_callback) {
|
||||
miscompare_callback(expected, actual, LiteralSlice(miscompared),
|
||||
shape_index);
|
||||
}
|
||||
case TOKEN:
|
||||
// Tokens have no on-device representation and are trivially equal.
|
||||
return Status::OK();
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported primitive type: "
|
||||
<< PrimitiveType_Name(expected.shape().element_type());
|
||||
}
|
||||
|
||||
return result;
|
||||
@ -726,9 +770,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
|
||||
// via recursion. shape_index is the ShapeIndex of expected (or actual)
|
||||
// currently being compared.
|
||||
Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
const ErrorSpec& error, absl::optional<bool> detailed_message,
|
||||
const MiscompareCallback& miscompare_callback,
|
||||
const ShapeIndex& shape_index) {
|
||||
const ShapeIndex& shape_index, const ErrorSpec& error,
|
||||
absl::optional<bool> detailed_message,
|
||||
const MiscompareCallback& miscompare_callback) {
|
||||
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
|
||||
|
||||
if (expected.shape().IsTuple()) {
|
||||
@ -739,8 +783,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
ShapeIndex element_index = shape_index;
|
||||
element_index.push_back(i);
|
||||
Status element_result =
|
||||
NearHelper(expected_element, actual_element, error, detailed_message,
|
||||
miscompare_callback, element_index);
|
||||
NearHelper(expected_element, actual_element, element_index, error,
|
||||
detailed_message, miscompare_callback);
|
||||
if (!element_result.ok()) {
|
||||
element_result = InvalidArgument("Array at shape index %s, %s",
|
||||
element_index.ToString(),
|
||||
@ -771,28 +815,34 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
ShapeUtil::ElementsIn(expected.shape()) >= 64);
|
||||
switch (expected.shape().element_type()) {
|
||||
case BF16:
|
||||
return NearComparator<bfloat16>::Compare(
|
||||
expected, actual, error, use_detailed_message, miscompare_callback);
|
||||
return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
|
||||
error, use_detailed_message,
|
||||
miscompare_callback);
|
||||
break;
|
||||
case F16:
|
||||
return NearComparator<half>::Compare(
|
||||
expected, actual, error, use_detailed_message, miscompare_callback);
|
||||
return NearComparator<half>::Compare(expected, actual, shape_index,
|
||||
error, use_detailed_message,
|
||||
miscompare_callback);
|
||||
break;
|
||||
case F32:
|
||||
return NearComparator<float>::Compare(
|
||||
expected, actual, error, use_detailed_message, miscompare_callback);
|
||||
return NearComparator<float>::Compare(expected, actual, shape_index,
|
||||
error, use_detailed_message,
|
||||
miscompare_callback);
|
||||
break;
|
||||
case F64:
|
||||
return NearComparator<double>::Compare(
|
||||
expected, actual, error, use_detailed_message, miscompare_callback);
|
||||
return NearComparator<double>::Compare(expected, actual, shape_index,
|
||||
error, use_detailed_message,
|
||||
miscompare_callback);
|
||||
break;
|
||||
case C64:
|
||||
return NearComparator<complex64>::Compare(
|
||||
expected, actual, error, use_detailed_message, miscompare_callback);
|
||||
return NearComparator<complex64>::Compare(expected, actual, shape_index,
|
||||
error, use_detailed_message,
|
||||
miscompare_callback);
|
||||
break;
|
||||
case C128:
|
||||
return NearComparator<complex128>::Compare(
|
||||
expected, actual, error, use_detailed_message, miscompare_callback);
|
||||
expected, actual, shape_index, error, use_detailed_message,
|
||||
miscompare_callback);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported primitive type in near comparator: "
|
||||
@ -802,7 +852,7 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
}
|
||||
|
||||
// Non-floating point, non-tuple literal.
|
||||
return EqualHelper(expected, actual);
|
||||
return EqualHelper(expected, actual, shape_index, miscompare_callback);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -878,7 +928,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
|
||||
XLA_VLOG_LINES(1, expected.ToString());
|
||||
VLOG(1) << "actual:";
|
||||
XLA_VLOG_LINES(1, actual.ToString());
|
||||
Status result = EqualHelper(expected, actual);
|
||||
Status result = EqualHelper(expected, actual, {}, nullptr);
|
||||
return EmitLiteralsInErrorMessage(result, expected, actual);
|
||||
}
|
||||
|
||||
@ -889,9 +939,8 @@ Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
XLA_VLOG_LINES(1, expected.ToString());
|
||||
VLOG(1) << "Actual literal:";
|
||||
XLA_VLOG_LINES(1, actual.ToString());
|
||||
Status result =
|
||||
NearHelper(expected, actual, error, detailed_message, miscompare_callback,
|
||||
/*shape_index=*/{});
|
||||
Status result = NearHelper(expected, actual, /*shape_index=*/{}, error,
|
||||
detailed_message, miscompare_callback);
|
||||
return EmitLiteralsInErrorMessage(result, expected, actual);
|
||||
}
|
||||
|
||||
|
@ -35,9 +35,9 @@ Status EqualShapes(const Shape& expected, const Shape& actual);
|
||||
// primitive type are equal.
|
||||
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual);
|
||||
|
||||
using MiscompareCallback =
|
||||
std::function<void(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
const LiteralSlice& mismatches)>;
|
||||
using MiscompareCallback = std::function<void(
|
||||
const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
const LiteralSlice& mismatches, const ShapeIndex& shape_index)>;
|
||||
|
||||
// Inspects whether the expected and actual literals are within the given error
|
||||
// bound for all elements. Also, inspects whether the rank, dimensions sizes,
|
||||
@ -57,6 +57,9 @@ using MiscompareCallback =
|
||||
// If detailed_message is true, then the error message in the assertion result
|
||||
// will contain a more detailed breakdown of mismatches. By default, we display
|
||||
// a detailed message only for "large" inputs.
|
||||
//
|
||||
// If miscompare_callback is nullptr, Near will return an error on the first
|
||||
// detected mismatch.
|
||||
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
const ErrorSpec& error, absl::optional<bool> detailed_message,
|
||||
const MiscompareCallback& miscompare_callback);
|
||||
|
@ -51,7 +51,8 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
|
||||
// Callback helper that dumps literals to temporary files in the event of a
|
||||
// miscomparison.
|
||||
void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
const LiteralSlice& mismatches) {
|
||||
const LiteralSlice& mismatches,
|
||||
const ShapeIndex& /*shape_index*/) {
|
||||
LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " "
|
||||
<< literal_comparison::ToStringTruncated(expected);
|
||||
LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " "
|
||||
|
Loading…
Reference in New Issue
Block a user