[XLA] Be slightly more relaxed about denormals in exhaustive_op_test.
Given subnormal x previously the test would accept any of the following as a "correct" value for f(x). - f(x) evaluated at full precision - f(copysign(0, x)), i.e. flush denormals to zero Now we add a third option: - f(copysign(minfloat, x)), i.e. flush denormals to the closest *normal* value This matches the behavior of XLA:CPU's log implementation when it's run with ftz disabled on the CPU. It also seems like a reasonable behavior, especially for log, where this behavior avoids an infinity, which may avoid a nan down the road. PiperOrigin-RevId: 247766690
This commit is contained in:
parent
7bbd0940e8
commit
1cf2df7f98
@ -338,6 +338,10 @@ class ExhaustiveOpTest
|
|||||||
// denormals.
|
// denormals.
|
||||||
const T expected_at_pos_zero = static_cast<T>(evaluate_op(0));
|
const T expected_at_pos_zero = static_cast<T>(evaluate_op(0));
|
||||||
const T expected_at_neg_zero = static_cast<T>(evaluate_op(-0.0));
|
const T expected_at_neg_zero = static_cast<T>(evaluate_op(-0.0));
|
||||||
|
const T expected_at_pos_min_normal_float =
|
||||||
|
static_cast<T>(evaluate_op(std::numeric_limits<float>::min()));
|
||||||
|
const T expected_at_neg_min_normal_float =
|
||||||
|
static_cast<T>(evaluate_op(-std::numeric_limits<float>::min()));
|
||||||
for (int64 i = 0; i < input_arr.size(); ++i) {
|
for (int64 i = 0; i < input_arr.size(); ++i) {
|
||||||
T input = input_arr[i];
|
T input = input_arr[i];
|
||||||
float input_f32 = static_cast<float>(input);
|
float input_f32 = static_cast<float>(input);
|
||||||
@ -369,13 +373,23 @@ class ExhaustiveOpTest
|
|||||||
// - evaluate_op(input)
|
// - evaluate_op(input)
|
||||||
// - evaluate_op(+/-0), where the sign of 0 equal to the sign of
|
// - evaluate_op(+/-0), where the sign of 0 equal to the sign of
|
||||||
// `input`,
|
// `input`,
|
||||||
|
// - evaluate_op(+/-min_normal_float), where the sign of
|
||||||
|
// min_normal_float matches `input`.
|
||||||
// - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
|
// - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
|
||||||
// 0 is the opposite of `input`.
|
// 0 is the opposite of `input`.
|
||||||
|
//
|
||||||
|
// (In particular, the XLA:CPU implementation of log flushes positive
|
||||||
|
// denormals to min-normal-float. This seems kind of reasonable if our
|
||||||
|
// goal is to avoid infinities because they cause nans?)
|
||||||
T sign_preserving_ftz_expected =
|
T sign_preserving_ftz_expected =
|
||||||
std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero;
|
std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero;
|
||||||
|
T flush_to_normal_expected = std::signbit(input_f32)
|
||||||
|
? expected_at_neg_min_normal_float
|
||||||
|
: expected_at_pos_min_normal_float;
|
||||||
T sign_nonpreserving_ftz_expected =
|
T sign_nonpreserving_ftz_expected =
|
||||||
std::signbit(input_f32) ? expected_at_pos_zero : expected_at_neg_zero;
|
std::signbit(input_f32) ? expected_at_pos_zero : expected_at_neg_zero;
|
||||||
if (IsClose(sign_preserving_ftz_expected, actual) ||
|
if (IsClose(sign_preserving_ftz_expected, actual) ||
|
||||||
|
IsClose(flush_to_normal_expected, actual) ||
|
||||||
(relaxed_denormal_signs_ &&
|
(relaxed_denormal_signs_ &&
|
||||||
IsClose(sign_nonpreserving_ftz_expected, actual))) {
|
IsClose(sign_nonpreserving_ftz_expected, actual))) {
|
||||||
continue;
|
continue;
|
||||||
@ -386,11 +400,13 @@ class ExhaustiveOpTest
|
|||||||
return absl::StrFormat(
|
return absl::StrFormat(
|
||||||
"Mismatch on denormal value %s. Expected one of:\n"
|
"Mismatch on denormal value %s. Expected one of:\n"
|
||||||
" %10s (evaluated at full-precision value)\n"
|
" %10s (evaluated at full-precision value)\n"
|
||||||
|
" %10s (evaluated at sign-preserving min-normal-float)\n"
|
||||||
" %10s (evaluated after flushing to sign-preserving zero)\n"
|
" %10s (evaluated after flushing to sign-preserving zero)\n"
|
||||||
" %10s (evaluated after flushing to non-sign-preserving "
|
" %10s (evaluated after flushing to non-sign-preserving "
|
||||||
"zero)\n"
|
"zero)\n"
|
||||||
"but got %s.",
|
"but got %s.",
|
||||||
StringifyNum(input), StringifyNum(expected),
|
StringifyNum(input), //
|
||||||
|
StringifyNum(expected), StringifyNum(flush_to_normal_expected),
|
||||||
StringifyNum(sign_preserving_ftz_expected),
|
StringifyNum(sign_preserving_ftz_expected),
|
||||||
StringifyNum(sign_nonpreserving_ftz_expected),
|
StringifyNum(sign_nonpreserving_ftz_expected),
|
||||||
StringifyNum(actual));
|
StringifyNum(actual));
|
||||||
@ -400,10 +416,13 @@ class ExhaustiveOpTest
|
|||||||
return absl::StrFormat(
|
return absl::StrFormat(
|
||||||
"Mismatch on denormal value %s. Expected one of:\n"
|
"Mismatch on denormal value %s. Expected one of:\n"
|
||||||
" %10s (evaluated at full-precision value)\n"
|
" %10s (evaluated at full-precision value)\n"
|
||||||
|
" %10s (evaluated at sign-preserving min-normal-float)\n"
|
||||||
" %10s (evaluated after flushing to sign-preserving zero)\n"
|
" %10s (evaluated after flushing to sign-preserving zero)\n"
|
||||||
"but got %s.",
|
"but got %s.",
|
||||||
StringifyNum(input), StringifyNum(expected),
|
StringifyNum(input), //
|
||||||
StringifyNum(sign_preserving_ftz_expected), StringifyNum(actual));
|
StringifyNum(expected), StringifyNum(flush_to_normal_expected),
|
||||||
|
StringifyNum(sign_preserving_ftz_expected), //
|
||||||
|
StringifyNum(actual));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user