diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc index adb1d395687..7df01e04c6d 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc @@ -338,6 +338,10 @@ class ExhaustiveOpTest // denormals. const T expected_at_pos_zero = static_cast(evaluate_op(0)); const T expected_at_neg_zero = static_cast(evaluate_op(-0.0)); + const T expected_at_pos_min_normal_float = + static_cast(evaluate_op(std::numeric_limits::min())); + const T expected_at_neg_min_normal_float = + static_cast(evaluate_op(-std::numeric_limits::min())); for (int64 i = 0; i < input_arr.size(); ++i) { T input = input_arr[i]; float input_f32 = static_cast(input); @@ -369,13 +373,23 @@ class ExhaustiveOpTest // - evaluate_op(input) // - evaluate_op(+/-0), where the sign of 0 equal to the sign of // `input`, + // - evaluate_op(+/-min_normal_float), where the sign of + // min_normal_float matches `input`. // - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of // 0 is the opposite of `input`. + // + // (In particular, the XLA:CPU implementation of log flushes positive + // denormals to min-normal-float. This seems kind of reasonable if our + // goal is to avoid infinities because they cause nans?) T sign_preserving_ftz_expected = std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero; + T flush_to_normal_expected = std::signbit(input_f32) + ? expected_at_neg_min_normal_float + : expected_at_pos_min_normal_float; T sign_nonpreserving_ftz_expected = std::signbit(input_f32) ? expected_at_pos_zero : expected_at_neg_zero; if (IsClose(sign_preserving_ftz_expected, actual) || + IsClose(flush_to_normal_expected, actual) || (relaxed_denormal_signs_ && IsClose(sign_nonpreserving_ftz_expected, actual))) { continue; @@ -386,11 +400,13 @@ class ExhaustiveOpTest return absl::StrFormat( "Mismatch on denormal value %s. Expected one of:\n" " %10s (evaluated at full-precision value)\n" + " %10s (evaluated at sign-preserving min-normal-float)\n" " %10s (evaluated after flushing to sign-preserving zero)\n" " %10s (evaluated after flushing to non-sign-preserving " "zero)\n" "but got %s.", - StringifyNum(input), StringifyNum(expected), + StringifyNum(input), // + StringifyNum(expected), StringifyNum(flush_to_normal_expected), StringifyNum(sign_preserving_ftz_expected), StringifyNum(sign_nonpreserving_ftz_expected), StringifyNum(actual)); @@ -400,10 +416,13 @@ class ExhaustiveOpTest return absl::StrFormat( "Mismatch on denormal value %s. Expected one of:\n" " %10s (evaluated at full-precision value)\n" + " %10s (evaluated at sign-preserving min-normal-float)\n" " %10s (evaluated after flushing to sign-preserving zero)\n" "but got %s.", - StringifyNum(input), StringifyNum(expected), - StringifyNum(sign_preserving_ftz_expected), StringifyNum(actual)); + StringifyNum(input), // + StringifyNum(expected), StringifyNum(flush_to_normal_expected), + StringifyNum(sign_preserving_ftz_expected), // + StringifyNum(actual)); }); } }