Merge pull request #30434 from DavidNorman:resolve-exhaustive-test-compile-issue
PiperOrigin-RevId: 259404629
This commit is contained in:
commit
250436404a
@ -58,19 +58,19 @@ ExhaustiveOpTestBase::CreateExhaustiveF32Ranges() {
|
||||
|
||||
namespace {
|
||||
ExhaustiveOpTestBase::ErrorSpec DefaultF64SpecGenerator(float) {
|
||||
return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001};
|
||||
return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
|
||||
}
|
||||
|
||||
ExhaustiveOpTestBase::ErrorSpec DefaultF32SpecGenerator(float) {
|
||||
return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001};
|
||||
return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
|
||||
}
|
||||
|
||||
ExhaustiveOpTestBase::ErrorSpec DefaultF16SpecGenerator(float) {
|
||||
return ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001};
|
||||
return ExhaustiveOpTestBase::ErrorSpec(0.001, 0.001);
|
||||
}
|
||||
|
||||
ExhaustiveOpTestBase::ErrorSpec DefaultBF16SpecGenerator(float) {
|
||||
return ExhaustiveOpTestBase::ErrorSpec{0.002, 0.02};
|
||||
return ExhaustiveOpTestBase::ErrorSpec(0.002, 0.02);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -30,6 +30,26 @@ limitations under the License.
|
||||
namespace xla {
|
||||
using Eigen::half;
|
||||
|
||||
namespace test_util {
|
||||
template <int N>
|
||||
struct IntegralTypeWithByteWidth {};
|
||||
|
||||
template <>
|
||||
struct IntegralTypeWithByteWidth<2> {
|
||||
using type = uint16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IntegralTypeWithByteWidth<4> {
|
||||
using type = uint32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IntegralTypeWithByteWidth<8> {
|
||||
using type = uint64;
|
||||
};
|
||||
} // namespace test_util
|
||||
|
||||
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
public:
|
||||
struct ErrorSpec {
|
||||
@ -41,6 +61,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
// spec; this only covers the case when both `expected` and `actual` are
|
||||
// equal to 0.
|
||||
bool strict_signed_zeros = false;
|
||||
|
||||
ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
|
||||
};
|
||||
|
||||
// `ty` is the primitive type being tested.
|
||||
@ -150,24 +172,6 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
struct IntegralTypeWithByteWidth {};
|
||||
|
||||
template <>
|
||||
struct IntegralTypeWithByteWidth<2> {
|
||||
using type = uint16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IntegralTypeWithByteWidth<4> {
|
||||
using type = uint32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IntegralTypeWithByteWidth<8> {
|
||||
using type = uint64;
|
||||
};
|
||||
|
||||
// Converts part or all bits in an uint64 to the value of the floating point
|
||||
// data type being tested.
|
||||
//
|
||||
@ -180,7 +184,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
// T is the type of the floating value represented by the `bits`.
|
||||
template <typename T>
|
||||
T ConvertValue(uint64 bits) {
|
||||
using I = typename IntegralTypeWithByteWidth<sizeof(T)>::type;
|
||||
using I = typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
|
||||
I used_bits = static_cast<I>(bits);
|
||||
return BitCast<T>(used_bits);
|
||||
}
|
||||
|
@ -364,7 +364,8 @@ class Exhaustive32BitOrLessUnaryTest
|
||||
// type being tested.
|
||||
template <typename T>
|
||||
void FillInput(Literal* input_literal) {
|
||||
using IntegralT = typename IntegralTypeWithByteWidth<sizeof(T)>::type;
|
||||
using IntegralT =
|
||||
typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
|
||||
int64 input_size = input_literal->element_count();
|
||||
int64 begin, end;
|
||||
std::tie(begin, end) = std::get<1>(GetParam());
|
||||
|
Loading…
Reference in New Issue
Block a user