[XLA] Add exhaustive binary tests for F32 and F64.
PiperOrigin-RevId: 261816030
This commit is contained in:
parent
d30d01dba7
commit
85a9058eff
@ -829,6 +829,46 @@ xla_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Builds exhaustive_binary_test.cc with BINARY_TEST_TARGET_F32 defined.
xla_test(
    name = "exhaustive_binary_test_f32",
    srcs = ["exhaustive_binary_test.cc"],
    backends = ["gpu", "cpu"],
    copts = ["-DBINARY_TEST_TARGET_F32"],
    # Very slow on the interpreter.
    real_hardware_only = True,
    shard_count = 48,
    tags = [
        "optonly",
        # This is a big test that we skip for capacity reasons in OSS testing.
        "no_oss",
    ],
    deps = [":exhaustive_op_test_utils"],
)
|
||||
|
||||
# Builds exhaustive_binary_test.cc with BINARY_TEST_TARGET_F64 defined.
xla_test(
    name = "exhaustive_binary_test_f64",
    srcs = ["exhaustive_binary_test.cc"],
    backends = ["gpu", "cpu"],
    copts = ["-DBINARY_TEST_TARGET_F64"],
    # Very slow on the interpreter.
    real_hardware_only = True,
    shard_count = 48,
    tags = [
        "optonly",
        # This is a big test that we skip for capacity reasons in OSS testing.
        "no_oss",
    ],
    deps = [":exhaustive_op_test_utils"],
)
|
||||
|
||||
xla_test(
|
||||
name = "reduce_precision_test",
|
||||
srcs = ["reduce_precision_test.cc"],
|
||||
|
@ -174,5 +174,219 @@ INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest,
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Exhaustive test for binary operations for float and double.
|
||||
//
|
||||
// Test parameter is a tuple of (FpValues, FpValues) describing the possible
|
||||
// values for each operand. The inputs for the test are the Cartesian product
|
||||
// of the possible values for the two operands.
|
||||
template <PrimitiveType T>
|
||||
class Exhaustive32BitOrMoreBinaryTest
|
||||
: public ExhaustiveBinaryTest<T>,
|
||||
public ::testing::WithParamInterface<std::tuple<FpValues, FpValues>> {
|
||||
protected:
|
||||
using typename ExhaustiveBinaryTest<T>::NativeT;
|
||||
using ExhaustiveBinaryTest<T>::ConvertAndReplaceKnownIncorrectValueWith;
|
||||
|
||||
private:
|
||||
int64 GetInputSize() override {
|
||||
FpValues values_0;
|
||||
FpValues values_1;
|
||||
std::tie(values_0, values_1) = GetParam();
|
||||
return values_0.GetTotalNumValues() * values_1.GetTotalNumValues();
|
||||
}
|
||||
|
||||
void FillInput(std::array<Literal, 2>* input_literals) override {
|
||||
int64 input_size = GetInputSize();
|
||||
FpValues values_0;
|
||||
FpValues values_1;
|
||||
std::tie(values_0, values_1) = GetParam();
|
||||
|
||||
VLOG(2) << " testing " << values_0.ToString() << " " << values_1.ToString()
|
||||
<< "total values " << input_size;
|
||||
CHECK(input_size == (*input_literals)[0].element_count() &&
|
||||
input_size == (*input_literals)[1].element_count());
|
||||
|
||||
absl::Span<NativeT> input_arr_0 = (*input_literals)[0].data<NativeT>();
|
||||
absl::Span<NativeT> input_arr_1 = (*input_literals)[1].data<NativeT>();
|
||||
|
||||
uint64 i = 0;
|
||||
for (auto src0 : values_0) {
|
||||
for (auto src1 : values_1) {
|
||||
input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(src0, 1);
|
||||
input_arr_1[i] = ConvertAndReplaceKnownIncorrectValueWith(src1, 1);
|
||||
++i;
|
||||
}
|
||||
}
|
||||
CHECK_EQ(i, input_size);
|
||||
}
|
||||
};
|
||||
|
||||
using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest<F32>;
|
||||
using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;
|
||||
|
||||
// Runs the XLA Add op with host-side float addition as the reference.
XLA_TEST_P(ExhaustiveF32BinaryTest, Add) {
  Run(AddEmptyBroadcastDimension(Add),
      [](float lhs, float rhs) { return lhs + rhs; });
}
|
||||
|
||||
// Runs the XLA Sub op with host-side float subtraction as the reference.
XLA_TEST_P(ExhaustiveF32BinaryTest, Sub) {
  Run(AddEmptyBroadcastDimension(Sub),
      [](float lhs, float rhs) { return lhs - rhs; });
}
|
||||
|
||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||
// Runs the XLA Mul op with host-side float multiplication as the reference.
XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Mul)) {
  Run(AddEmptyBroadcastDimension(Mul),
      [](float lhs, float rhs) { return lhs * rhs; });
}
|
||||
|
||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||
// Runs the XLA Div op with host-side float division as the reference.
XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Div)) {
  Run(AddEmptyBroadcastDimension(Div),
      [](float lhs, float rhs) { return lhs / rhs; });
}
|
||||
|
||||
// Runs the XLA Max op with ReferenceMax<float> as the host reference.
XLA_TEST_P(ExhaustiveF32BinaryTest, Max) {
  auto device_max = AddEmptyBroadcastDimension(Max);
  Run(device_max, ReferenceMax<float>);
}
|
||||
|
||||
// Runs the XLA Min op with ReferenceMin<float> as the host reference.
XLA_TEST_P(ExhaustiveF32BinaryTest, Min) {
  auto device_min = AddEmptyBroadcastDimension(Min);
  Run(device_min, ReferenceMin<float>);
}
|
||||
|
||||
// It is more convenient to implement Abs(complex) as a binary op than a unary
|
||||
// op, as the operations we currently support all have the same data type for
|
||||
// the source operands and the results.
|
||||
// TODO(bixia): May want to move this test to unary test if we will be able to
|
||||
// implement Abs(complex) as unary conveniently.
|
||||
//
|
||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||
// Treats the two float operands as the real and imaginary parts of a complex
// number and runs Abs(Complex(re, im)) against std::abs on the host.
XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(AbsComplex)) {
  auto device_op = [](XlaOp re, XlaOp im) { return Abs(Complex(re, im)); };
  auto host_op = [](float re, float im) {
    return std::abs(std::complex<float>(re, im));
  };

  Run(device_op, host_op);
}
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_F32)
|
||||
|
||||
// Both operands range over the FpValues returned by
// CreateFpValuesForBoundaryTest<float>(); ::testing::Combine pairs every
// first-operand entry with every second-operand entry.
INSTANTIATE_TEST_SUITE_P(
    SpecialValues, ExhaustiveF32BinaryTest,
    ::testing::Combine(
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
|
||||
|
||||
// First operand: each FpValues from CreateFpValuesForBoundaryTest<float>().
// Second operand: the single FpValues from GetNormals<float>(2000).
INSTANTIATE_TEST_SUITE_P(
    SpecialAndNormalValues, ExhaustiveF32BinaryTest,
    ::testing::Combine(
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
        ::testing::Values(GetNormals<float>(2000))));
|
||||
|
||||
// First operand: the single FpValues from GetNormals<float>(2000).
// Second operand: each FpValues from CreateFpValuesForBoundaryTest<float>().
INSTANTIATE_TEST_SUITE_P(
    NormalAndSpecialValues, ExhaustiveF32BinaryTest,
    ::testing::Combine(
        ::testing::Values(GetNormals<float>(2000)),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
|
||||
|
||||
// Both operands use the FpValues from GetNormals<float>(2000).
INSTANTIATE_TEST_SUITE_P(
    NormalAndNormalValues, ExhaustiveF32BinaryTest,
    ::testing::Combine(::testing::Values(GetNormals<float>(2000)),
                       ::testing::Values(GetNormals<float>(2000))));
|
||||
|
||||
// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test.
|
||||
// Comparing with the unary tests, the binary tests use a smaller set of inputs
|
||||
// for each sub-test to avoid timeout because the implementation of ExpectNear
|
||||
// is more than 2x slower for binary test.
|
||||
// ::testing::Combine pairs every FpValues in the first list with every
// FpValues in the second; both lists come from the same generator call.
INSTANTIATE_TEST_SUITE_P(
    LargeAndSmallMagnituedNormalValues, ExhaustiveF32BinaryTest,
    ::testing::Combine(
        ::testing::ValuesIn(
            GetFpValuesForMagnitudeExtremeNormals<float>(40000, 2000)),
        ::testing::ValuesIn(
            GetFpValuesForMagnitudeExtremeNormals<float>(40000, 2000))));
|
||||
|
||||
#endif
|
||||
|
||||
// Runs the XLA Add op with host-side double addition as the reference.
XLA_TEST_P(ExhaustiveF64BinaryTest, Add) {
  Run(AddEmptyBroadcastDimension(Add),
      [](double lhs, double rhs) { return lhs + rhs; });
}
|
||||
|
||||
// Runs the XLA Sub op with host-side double subtraction as the reference.
XLA_TEST_P(ExhaustiveF64BinaryTest, Sub) {
  Run(AddEmptyBroadcastDimension(Sub),
      [](double lhs, double rhs) { return lhs - rhs; });
}
|
||||
|
||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||
// Runs the XLA Mul op with host-side double multiplication as the reference.
XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Mul)) {
  Run(AddEmptyBroadcastDimension(Mul),
      [](double lhs, double rhs) { return lhs * rhs; });
}
|
||||
|
||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||
// Runs the XLA Div op with host-side double division as the reference.
XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Div)) {
  Run(AddEmptyBroadcastDimension(Div),
      [](double lhs, double rhs) { return lhs / rhs; });
}
|
||||
|
||||
// Runs the XLA Max op with ReferenceMax<double> as the host reference.
XLA_TEST_P(ExhaustiveF64BinaryTest, Max) {
  auto device_max = AddEmptyBroadcastDimension(Max);
  Run(device_max, ReferenceMax<double>);
}
|
||||
|
||||
// Runs the XLA Min op with ReferenceMin<double> as the host reference.
XLA_TEST_P(ExhaustiveF64BinaryTest, Min) {
  auto device_min = AddEmptyBroadcastDimension(Min);
  Run(device_min, ReferenceMin<double>);
}
|
||||
|
||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||
// Treats the two double operands as the real and imaginary parts of a complex
// number and runs Abs(Complex(re, im)) against std::abs on the host.
XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(AbsComplex)) {
  auto device_op = [](XlaOp re, XlaOp im) { return Abs(Complex(re, im)); };
  auto host_op = [](double re, double im) {
    return std::abs(std::complex<double>(re, im));
  };

  Run(device_op, host_op);
}
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_F64)
|
||||
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||
// Both operands range over the FpValues returned by
// CreateFpValuesForBoundaryTest<double>(); ::testing::Combine pairs every
// first-operand entry with every second-operand entry.
INSTANTIATE_TEST_SUITE_P(
    SpecialValues, ExhaustiveF64BinaryTest,
    ::testing::Combine(
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
|
||||
|
||||
// First operand: each FpValues from CreateFpValuesForBoundaryTest<double>().
// Second operand: the single FpValues from GetNormals<double>(1000).
INSTANTIATE_TEST_SUITE_P(
    SpecialAndNormalValues, ExhaustiveF64BinaryTest,
    ::testing::Combine(
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
        ::testing::Values(GetNormals<double>(1000))));
|
||||
|
||||
// First operand: the single FpValues from GetNormals<double>(1000).
// Second operand: each FpValues from CreateFpValuesForBoundaryTest<double>().
INSTANTIATE_TEST_SUITE_P(
    NormalAndSpecialValues, ExhaustiveF64BinaryTest,
    ::testing::Combine(
        ::testing::Values(GetNormals<double>(1000)),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
|
||||
|
||||
// Both operands use the FpValues from GetNormals<double>(1000).
INSTANTIATE_TEST_SUITE_P(
    NormalAndNormalValues, ExhaustiveF64BinaryTest,
    ::testing::Combine(::testing::Values(GetNormals<double>(1000)),
                       ::testing::Values(GetNormals<double>(1000))));
|
||||
|
||||
// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test.
|
||||
// Similar to ExhaustiveF32BinaryTest, we use a smaller set of inputs
|
||||
// for each sub-test compared with the unary test to avoid timeout.
|
||||
// ::testing::Combine pairs every FpValues in the first list with every
// FpValues in the second; both lists come from the same generator call.
INSTANTIATE_TEST_SUITE_P(
    LargeAndSmallMagnituedNormalValues, ExhaustiveF64BinaryTest,
    ::testing::Combine(
        ::testing::ValuesIn(
            GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
        ::testing::ValuesIn(
            GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
|
||||
#endif
|
||||
|
||||
#endif
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -215,6 +215,18 @@ inline ExhaustiveOpTestBase<BF16, 1>::ErrorSpec DefaultSpecGenerator<BF16, 1>(
|
||||
return ExhaustiveOpTestBase<BF16, 1>::ErrorSpec{0.002, 0.02};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline ExhaustiveOpTestBase<F64, 2>::ErrorSpec DefaultSpecGenerator<F64, 2>(
|
||||
double, double) {
|
||||
return ExhaustiveOpTestBase<F64, 2>::ErrorSpec{0.001, 0.001};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline ExhaustiveOpTestBase<F32, 2>::ErrorSpec DefaultSpecGenerator<F32, 2>(
|
||||
float, float) {
|
||||
return ExhaustiveOpTestBase<F32, 2>::ErrorSpec{0.001, 0.001};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline ExhaustiveOpTestBase<F16, 2>::ErrorSpec DefaultSpecGenerator<F16, 2>(
|
||||
Eigen::half, Eigen::half) {
|
||||
@ -242,6 +254,8 @@ template class ExhaustiveOpTestBase<F32, 1>;
|
||||
template class ExhaustiveOpTestBase<F16, 1>;
|
||||
template class ExhaustiveOpTestBase<BF16, 1>;
|
||||
|
||||
template class ExhaustiveOpTestBase<F64, 2>;
|
||||
template class ExhaustiveOpTestBase<F32, 2>;
|
||||
template class ExhaustiveOpTestBase<F16, 2>;
|
||||
template class ExhaustiveOpTestBase<BF16, 2>;
|
||||
|
||||
|
@ -1001,6 +1001,7 @@ class FpValues {
|
||||
const FpValues* fp_values_;
|
||||
};
|
||||
|
||||
// Default constructor: value-initializes bit_chunks_ and offsets_.
FpValues() : bit_chunks_(), offsets_() {}
|
||||
FpValues(absl::Span<const BitChunks> chunks, absl::Span<const int> offsets) {
|
||||
CHECK_EQ(chunks.size(), offsets.size() - 1);
|
||||
CHECK_EQ(chunks.size(), kTotalBitChunks);
|
||||
|
Loading…
Reference in New Issue
Block a user