[XLA] Enable exhaustive tests for sin/cos/tan.
Evaluate tan(fp16) in fp32 precision so it passes the tests. PiperOrigin-RevId: 247226089
This commit is contained in:
parent
fdd37437ea
commit
b17070be2a
@ -528,7 +528,9 @@ XlaOp Asin(XlaOp x) {
|
|||||||
|
|
||||||
XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
|
XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
|
||||||
|
|
||||||
XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); }
|
XlaOp Tan(XlaOp x) {
|
||||||
|
return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); });
|
||||||
|
}
|
||||||
|
|
||||||
// Hyperbolic trigonometric functions.
|
// Hyperbolic trigonometric functions.
|
||||||
|
|
||||||
|
@ -245,14 +245,6 @@ class ExhaustiveOpTest
|
|||||||
int64 begin, end;
|
int64 begin, end;
|
||||||
std::tie(begin, end) = test_range;
|
std::tie(begin, end) = test_range;
|
||||||
|
|
||||||
if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) {
|
|
||||||
LOG(INFO) << absl::StreamFormat(
|
|
||||||
"Skipping this shard, as the range under test, [%d, %d), falls "
|
|
||||||
"entirely within the known-incorrect range [%d, %d).",
|
|
||||||
begin, end, known_incorrect_begin_, known_incorrect_end_);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG(INFO) << "Checking range [" << begin << ", " << end << ")";
|
LOG(INFO) << "Checking range [" << begin << ", " << end << ")";
|
||||||
|
|
||||||
int64 input_size = end - begin;
|
int64 input_size = end - begin;
|
||||||
@ -262,8 +254,7 @@ class ExhaustiveOpTest
|
|||||||
IntegralT input_val = i + begin;
|
IntegralT input_val = i + begin;
|
||||||
// If the operation is known to be buggy on a specific input clamp that
|
// If the operation is known to be buggy on a specific input clamp that
|
||||||
// input to 0 under the assumption that the op is at least correct on 0.
|
// input to 0 under the assumption that the op is at least correct on 0.
|
||||||
if (input_val >= known_incorrect_begin_ &&
|
if (known_incorrect_fn_ && known_incorrect_fn_(input_val)) {
|
||||||
input_val < known_incorrect_end_) {
|
|
||||||
input_arr[i] = T{0};
|
input_arr[i] = T{0};
|
||||||
} else {
|
} else {
|
||||||
input_arr[i] = absl::bit_cast<T>(input_val);
|
input_arr[i] = absl::bit_cast<T>(input_val);
|
||||||
@ -439,6 +430,9 @@ class ExhaustiveOpTest
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sets error parameters appropriately for testing sin/cos/tan.
|
||||||
|
void SetParamsForSinCosTan();
|
||||||
|
|
||||||
// The following members are set during construction so testcases can read
|
// The following members are set during construction so testcases can read
|
||||||
// these values and use them e.g. to influence the values given to the mutable
|
// these values and use them e.g. to influence the values given to the mutable
|
||||||
// members below.
|
// members below.
|
||||||
@ -452,10 +446,9 @@ class ExhaustiveOpTest
|
|||||||
// Tests can set the following variables for control over execution. This is
|
// Tests can set the following variables for control over execution. This is
|
||||||
// safe because each XLA_TEST_P instantiates a new instance of this class.
|
// safe because each XLA_TEST_P instantiates a new instance of this class.
|
||||||
|
|
||||||
// Testing will ignore the given range (encoded as bitwise representations of
|
// Testing will ignore inputs for which known_incorect_fn_ returns true. (Its
|
||||||
// the type under test zero-extended to int64).
|
// argument is the type under test, e.g. f32, zero-extended to int64).
|
||||||
int64 known_incorrect_begin_ = 0;
|
std::function<bool(int64)> known_incorrect_fn_;
|
||||||
int64 known_incorrect_end_ = 0;
|
|
||||||
|
|
||||||
// If unset, reasonable defaults will be used depending on the type under
|
// If unset, reasonable defaults will be used depending on the type under
|
||||||
// test.
|
// test.
|
||||||
@ -616,11 +609,45 @@ XLA_TEST_P(ExhaustiveOpTest, Sinh) {
|
|||||||
}
|
}
|
||||||
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
|
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
|
||||||
|
|
||||||
|
void ExhaustiveOpTest::SetParamsForSinCosTan() {
|
||||||
|
if (platform_ == "Host" || platform_ == "CUDA") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non CPU/GPU targets may have used the Cody-Waite range reduction technique
|
||||||
|
// and will not provide meaningful results for sin/cos/tan if magnitudes
|
||||||
|
// exceed 2**p.
|
||||||
|
if (ty_ == F32) {
|
||||||
|
rel_err_ = 0.001;
|
||||||
|
abs_err_ = 0.001;
|
||||||
|
known_incorrect_fn_ = [](int64 v) {
|
||||||
|
float f = absl::bit_cast<float>(static_cast<uint32>(v));
|
||||||
|
return std::abs(f) > (1 << 13);
|
||||||
|
};
|
||||||
|
} else if (ty_ == BF16) {
|
||||||
|
known_incorrect_fn_ = [](int64 v) {
|
||||||
|
float f =
|
||||||
|
static_cast<float>(absl::bit_cast<bfloat16>(static_cast<uint16>(v)));
|
||||||
|
return std::abs(f) > (1 << 13);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_P(ExhaustiveOpTest, Cos) {
|
||||||
|
SetParamsForSinCosTan();
|
||||||
|
Run(Cos, std::cos);
|
||||||
|
}
|
||||||
|
XLA_TEST_P(ExhaustiveOpTest, Sin) {
|
||||||
|
SetParamsForSinCosTan();
|
||||||
|
Run(Sin, std::sin);
|
||||||
|
}
|
||||||
|
XLA_TEST_P(ExhaustiveOpTest, Tan) {
|
||||||
|
SetParamsForSinCosTan();
|
||||||
|
Run(Tan, std::tan);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(jlebar): Enable these.
|
// TODO(jlebar): Enable these.
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
|
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Cos) { Run(Cos, std::cos); }
|
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Sin) { Run(Sin, std::sin); }
|
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Tan) { Run(Tan, std::tan); }
|
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Atan2) { Run(Atan2, std::atan2); }
|
// XLA_TEST_P(ExhaustiveOpTest, Atan2) { Run(Atan2, std::atan2); }
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
|
XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
|
||||||
@ -661,19 +688,24 @@ XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
|
|||||||
if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) {
|
if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) {
|
||||||
rel_err_ = 0.001;
|
rel_err_ = 0.001;
|
||||||
}
|
}
|
||||||
|
float (*host_lgamma)(float) = std::lgamma;
|
||||||
if (platform_ != "Host" && platform_ != "CUDA") {
|
if (platform_ != "Host" && platform_ != "CUDA") {
|
||||||
// TODO(b/123956399): This is a fairly high error, significantly higher than
|
// TODO(b/123956399): This is a fairly high error, significantly higher than
|
||||||
// we see on CPU/GPU.
|
// we see on CPU/GPU.
|
||||||
rel_err_ = 0.01;
|
rel_err_ = 0.01;
|
||||||
abs_err_ = 0.01;
|
abs_err_ = 0.01;
|
||||||
|
|
||||||
// Overflows for to inf for input 4.08500343e+36 (0x7c44af8e).
|
// Overflows to inf for input 4.08500343e+36 (0x7c44af8e).
|
||||||
if (ty_ == F32) {
|
if (ty_ == F32) {
|
||||||
known_incorrect_begin_ = 0x7c44af8e;
|
host_lgamma = +[](float v) {
|
||||||
known_incorrect_end_ = 0x7c44af8e + 1;
|
if (absl::bit_cast<uint32>(v) == 0x7c44af8e) {
|
||||||
|
return std::numeric_limits<float>::infinity();
|
||||||
|
}
|
||||||
|
return std::lgamma(v);
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Run(Lgamma, std::lgamma);
|
Run(Lgamma, host_lgamma);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); }
|
XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); }
|
||||||
|
Loading…
Reference in New Issue
Block a user