[XLA] Enable exhaustive tests for sin/cos/tan.
Evaluate tan(fp16) in fp32 precision so it passes the tests. PiperOrigin-RevId: 247226089
This commit is contained in:
parent
fdd37437ea
commit
b17070be2a
@ -528,7 +528,9 @@ XlaOp Asin(XlaOp x) {
|
||||
|
||||
XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
|
||||
|
||||
XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); }
|
||||
XlaOp Tan(XlaOp x) {
|
||||
return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); });
|
||||
}
|
||||
|
||||
// Hyperbolic trigonometric functions.
|
||||
|
||||
|
@ -245,14 +245,6 @@ class ExhaustiveOpTest
|
||||
int64 begin, end;
|
||||
std::tie(begin, end) = test_range;
|
||||
|
||||
if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) {
|
||||
LOG(INFO) << absl::StreamFormat(
|
||||
"Skipping this shard, as the range under test, [%d, %d), falls "
|
||||
"entirely within the known-incorrect range [%d, %d).",
|
||||
begin, end, known_incorrect_begin_, known_incorrect_end_);
|
||||
return;
|
||||
}
|
||||
|
||||
LOG(INFO) << "Checking range [" << begin << ", " << end << ")";
|
||||
|
||||
int64 input_size = end - begin;
|
||||
@ -262,8 +254,7 @@ class ExhaustiveOpTest
|
||||
IntegralT input_val = i + begin;
|
||||
// If the operation is known to be buggy on a specific input clamp that
|
||||
// input to 0 under the assumption that the op is at least correct on 0.
|
||||
if (input_val >= known_incorrect_begin_ &&
|
||||
input_val < known_incorrect_end_) {
|
||||
if (known_incorrect_fn_ && known_incorrect_fn_(input_val)) {
|
||||
input_arr[i] = T{0};
|
||||
} else {
|
||||
input_arr[i] = absl::bit_cast<T>(input_val);
|
||||
@ -439,6 +430,9 @@ class ExhaustiveOpTest
|
||||
}
|
||||
}
|
||||
|
||||
// Sets error parameters appropriately for testing sin/cos/tan.
|
||||
void SetParamsForSinCosTan();
|
||||
|
||||
// The following members are set during construction so testcases can read
|
||||
// these values and use them e.g. to influence the values given to the mutable
|
||||
// members below.
|
||||
@ -452,10 +446,9 @@ class ExhaustiveOpTest
|
||||
// Tests can set the following variables for control over execution. This is
|
||||
// safe because each XLA_TEST_P instantiates a new instance of this class.
|
||||
|
||||
// Testing will ignore the given range (encoded as bitwise representations of
|
||||
// the type under test zero-extended to int64).
|
||||
int64 known_incorrect_begin_ = 0;
|
||||
int64 known_incorrect_end_ = 0;
|
||||
// Testing will ignore inputs for which known_incorect_fn_ returns true. (Its
|
||||
// argument is the type under test, e.g. f32, zero-extended to int64).
|
||||
std::function<bool(int64)> known_incorrect_fn_;
|
||||
|
||||
// If unset, reasonable defaults will be used depending on the type under
|
||||
// test.
|
||||
@ -616,11 +609,45 @@ XLA_TEST_P(ExhaustiveOpTest, Sinh) {
|
||||
}
|
||||
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
|
||||
|
||||
void ExhaustiveOpTest::SetParamsForSinCosTan() {
|
||||
if (platform_ == "Host" || platform_ == "CUDA") {
|
||||
return;
|
||||
}
|
||||
|
||||
// Non CPU/GPU targets may have used the Cody-Waite range reduction technique
|
||||
// and will not provide meaningful results for sin/cos/tan if magnitudes
|
||||
// exceed 2**p.
|
||||
if (ty_ == F32) {
|
||||
rel_err_ = 0.001;
|
||||
abs_err_ = 0.001;
|
||||
known_incorrect_fn_ = [](int64 v) {
|
||||
float f = absl::bit_cast<float>(static_cast<uint32>(v));
|
||||
return std::abs(f) > (1 << 13);
|
||||
};
|
||||
} else if (ty_ == BF16) {
|
||||
known_incorrect_fn_ = [](int64 v) {
|
||||
float f =
|
||||
static_cast<float>(absl::bit_cast<bfloat16>(static_cast<uint16>(v)));
|
||||
return std::abs(f) > (1 << 13);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
XLA_TEST_P(ExhaustiveOpTest, Cos) {
|
||||
SetParamsForSinCosTan();
|
||||
Run(Cos, std::cos);
|
||||
}
|
||||
XLA_TEST_P(ExhaustiveOpTest, Sin) {
|
||||
SetParamsForSinCosTan();
|
||||
Run(Sin, std::sin);
|
||||
}
|
||||
XLA_TEST_P(ExhaustiveOpTest, Tan) {
|
||||
SetParamsForSinCosTan();
|
||||
Run(Tan, std::tan);
|
||||
}
|
||||
|
||||
// TODO(jlebar): Enable these.
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Cos) { Run(Cos, std::cos); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Sin) { Run(Sin, std::sin); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Tan) { Run(Tan, std::tan); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Atan2) { Run(Atan2, std::atan2); }
|
||||
|
||||
XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
|
||||
@ -661,19 +688,24 @@ XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
|
||||
if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) {
|
||||
rel_err_ = 0.001;
|
||||
}
|
||||
float (*host_lgamma)(float) = std::lgamma;
|
||||
if (platform_ != "Host" && platform_ != "CUDA") {
|
||||
// TODO(b/123956399): This is a fairly high error, significantly higher than
|
||||
// we see on CPU/GPU.
|
||||
rel_err_ = 0.01;
|
||||
abs_err_ = 0.01;
|
||||
|
||||
// Overflows for to inf for input 4.08500343e+36 (0x7c44af8e).
|
||||
// Overflows to inf for input 4.08500343e+36 (0x7c44af8e).
|
||||
if (ty_ == F32) {
|
||||
known_incorrect_begin_ = 0x7c44af8e;
|
||||
known_incorrect_end_ = 0x7c44af8e + 1;
|
||||
host_lgamma = +[](float v) {
|
||||
if (absl::bit_cast<uint32>(v) == 0x7c44af8e) {
|
||||
return std::numeric_limits<float>::infinity();
|
||||
}
|
||||
return std::lgamma(v);
|
||||
};
|
||||
}
|
||||
}
|
||||
Run(Lgamma, std::lgamma);
|
||||
Run(Lgamma, host_lgamma);
|
||||
}
|
||||
|
||||
XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); }
|
||||
|
Loading…
Reference in New Issue
Block a user