[XLA] Enable exhaustive tests for sin/cos/tan.

Evaluate tan(fp16) in fp32 precision so it passes the tests.

PiperOrigin-RevId: 247226089
This commit is contained in:
Justin Lebar 2019-05-08 09:18:50 -07:00 committed by TensorFlower Gardener
parent fdd37437ea
commit b17070be2a
2 changed files with 56 additions and 22 deletions

View File

@ -528,7 +528,9 @@ XlaOp Asin(XlaOp x) {
XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); } XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); } XlaOp Tan(XlaOp x) {
return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); });
}
// Hyperbolic trigonometric functions. // Hyperbolic trigonometric functions.

View File

@ -245,14 +245,6 @@ class ExhaustiveOpTest
int64 begin, end; int64 begin, end;
std::tie(begin, end) = test_range; std::tie(begin, end) = test_range;
if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) {
LOG(INFO) << absl::StreamFormat(
"Skipping this shard, as the range under test, [%d, %d), falls "
"entirely within the known-incorrect range [%d, %d).",
begin, end, known_incorrect_begin_, known_incorrect_end_);
return;
}
LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; LOG(INFO) << "Checking range [" << begin << ", " << end << ")";
int64 input_size = end - begin; int64 input_size = end - begin;
@ -262,8 +254,7 @@ class ExhaustiveOpTest
IntegralT input_val = i + begin; IntegralT input_val = i + begin;
// If the operation is known to be buggy on a specific input clamp that // If the operation is known to be buggy on a specific input clamp that
// input to 0 under the assumption that the op is at least correct on 0. // input to 0 under the assumption that the op is at least correct on 0.
if (input_val >= known_incorrect_begin_ && if (known_incorrect_fn_ && known_incorrect_fn_(input_val)) {
input_val < known_incorrect_end_) {
input_arr[i] = T{0}; input_arr[i] = T{0};
} else { } else {
input_arr[i] = absl::bit_cast<T>(input_val); input_arr[i] = absl::bit_cast<T>(input_val);
@ -439,6 +430,9 @@ class ExhaustiveOpTest
} }
} }
// Sets error parameters appropriately for testing sin/cos/tan.
void SetParamsForSinCosTan();
// The following members are set during construction so testcases can read // The following members are set during construction so testcases can read
// these values and use them e.g. to influence the values given to the mutable // these values and use them e.g. to influence the values given to the mutable
// members below. // members below.
@ -452,10 +446,9 @@ class ExhaustiveOpTest
// Tests can set the following variables for control over execution. This is // Tests can set the following variables for control over execution. This is
// safe because each XLA_TEST_P instantiates a new instance of this class. // safe because each XLA_TEST_P instantiates a new instance of this class.
// Testing will ignore the given range (encoded as bitwise representations of // Testing will ignore inputs for which known_incorect_fn_ returns true. (Its
// the type under test zero-extended to int64). // argument is the type under test, e.g. f32, zero-extended to int64).
int64 known_incorrect_begin_ = 0; std::function<bool(int64)> known_incorrect_fn_;
int64 known_incorrect_end_ = 0;
// If unset, reasonable defaults will be used depending on the type under // If unset, reasonable defaults will be used depending on the type under
// test. // test.
@ -616,11 +609,45 @@ XLA_TEST_P(ExhaustiveOpTest, Sinh) {
} }
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); } XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
void ExhaustiveOpTest::SetParamsForSinCosTan() {
if (platform_ == "Host" || platform_ == "CUDA") {
return;
}
// Non CPU/GPU targets may have used the Cody-Waite range reduction technique
// and will not provide meaningful results for sin/cos/tan if magnitudes
// exceed 2**p.
if (ty_ == F32) {
rel_err_ = 0.001;
abs_err_ = 0.001;
known_incorrect_fn_ = [](int64 v) {
float f = absl::bit_cast<float>(static_cast<uint32>(v));
return std::abs(f) > (1 << 13);
};
} else if (ty_ == BF16) {
known_incorrect_fn_ = [](int64 v) {
float f =
static_cast<float>(absl::bit_cast<bfloat16>(static_cast<uint16>(v)));
return std::abs(f) > (1 << 13);
};
}
}
XLA_TEST_P(ExhaustiveOpTest, Cos) {
SetParamsForSinCosTan();
Run(Cos, std::cos);
}
XLA_TEST_P(ExhaustiveOpTest, Sin) {
SetParamsForSinCosTan();
Run(Sin, std::sin);
}
XLA_TEST_P(ExhaustiveOpTest, Tan) {
SetParamsForSinCosTan();
Run(Tan, std::tan);
}
// TODO(jlebar): Enable these. // TODO(jlebar): Enable these.
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); } // XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
// XLA_TEST_P(ExhaustiveOpTest, Cos) { Run(Cos, std::cos); }
// XLA_TEST_P(ExhaustiveOpTest, Sin) { Run(Sin, std::sin); }
// XLA_TEST_P(ExhaustiveOpTest, Tan) { Run(Tan, std::tan); }
// XLA_TEST_P(ExhaustiveOpTest, Atan2) { Run(Atan2, std::atan2); } // XLA_TEST_P(ExhaustiveOpTest, Atan2) { Run(Atan2, std::atan2); }
XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); } XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
@ -661,19 +688,24 @@ XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) { if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) {
rel_err_ = 0.001; rel_err_ = 0.001;
} }
float (*host_lgamma)(float) = std::lgamma;
if (platform_ != "Host" && platform_ != "CUDA") { if (platform_ != "Host" && platform_ != "CUDA") {
// TODO(b/123956399): This is a fairly high error, significantly higher than // TODO(b/123956399): This is a fairly high error, significantly higher than
// we see on CPU/GPU. // we see on CPU/GPU.
rel_err_ = 0.01; rel_err_ = 0.01;
abs_err_ = 0.01; abs_err_ = 0.01;
// Overflows for to inf for input 4.08500343e+36 (0x7c44af8e). // Overflows to inf for input 4.08500343e+36 (0x7c44af8e).
if (ty_ == F32) { if (ty_ == F32) {
known_incorrect_begin_ = 0x7c44af8e; host_lgamma = +[](float v) {
known_incorrect_end_ = 0x7c44af8e + 1; if (absl::bit_cast<uint32>(v) == 0x7c44af8e) {
return std::numeric_limits<float>::infinity();
}
return std::lgamma(v);
};
} }
} }
Run(Lgamma, std::lgamma); Run(Lgamma, host_lgamma);
} }
XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); } XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); }