Merge pull request #30434 from DavidNorman:resolve-exhaustive-test-compile-issue
PiperOrigin-RevId: 259404629
This commit is contained in:
commit
250436404a
@ -58,19 +58,19 @@ ExhaustiveOpTestBase::CreateExhaustiveF32Ranges() {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
ExhaustiveOpTestBase::ErrorSpec DefaultF64SpecGenerator(float) {
|
ExhaustiveOpTestBase::ErrorSpec DefaultF64SpecGenerator(float) {
|
||||||
return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001};
|
return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
|
||||||
}
|
}
|
||||||
|
|
||||||
ExhaustiveOpTestBase::ErrorSpec DefaultF32SpecGenerator(float) {
|
ExhaustiveOpTestBase::ErrorSpec DefaultF32SpecGenerator(float) {
|
||||||
return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001};
|
return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
|
||||||
}
|
}
|
||||||
|
|
||||||
ExhaustiveOpTestBase::ErrorSpec DefaultF16SpecGenerator(float) {
|
ExhaustiveOpTestBase::ErrorSpec DefaultF16SpecGenerator(float) {
|
||||||
return ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001};
|
return ExhaustiveOpTestBase::ErrorSpec(0.001, 0.001);
|
||||||
}
|
}
|
||||||
|
|
||||||
ExhaustiveOpTestBase::ErrorSpec DefaultBF16SpecGenerator(float) {
|
ExhaustiveOpTestBase::ErrorSpec DefaultBF16SpecGenerator(float) {
|
||||||
return ExhaustiveOpTestBase::ErrorSpec{0.002, 0.02};
|
return ExhaustiveOpTestBase::ErrorSpec(0.002, 0.02);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -30,6 +30,26 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
using Eigen::half;
|
using Eigen::half;
|
||||||
|
|
||||||
|
namespace test_util {
|
||||||
|
template <int N>
|
||||||
|
struct IntegralTypeWithByteWidth {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IntegralTypeWithByteWidth<2> {
|
||||||
|
using type = uint16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IntegralTypeWithByteWidth<4> {
|
||||||
|
using type = uint32;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IntegralTypeWithByteWidth<8> {
|
||||||
|
using type = uint64;
|
||||||
|
};
|
||||||
|
} // namespace test_util
|
||||||
|
|
||||||
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||||
public:
|
public:
|
||||||
struct ErrorSpec {
|
struct ErrorSpec {
|
||||||
@ -41,6 +61,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
|||||||
// spec; this only covers the case when both `expected` and `actual` are
|
// spec; this only covers the case when both `expected` and `actual` are
|
||||||
// equal to 0.
|
// equal to 0.
|
||||||
bool strict_signed_zeros = false;
|
bool strict_signed_zeros = false;
|
||||||
|
|
||||||
|
ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// `ty` is the primitive type being tested.
|
// `ty` is the primitive type being tested.
|
||||||
@ -150,24 +172,6 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int N>
|
|
||||||
struct IntegralTypeWithByteWidth {};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct IntegralTypeWithByteWidth<2> {
|
|
||||||
using type = uint16;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct IntegralTypeWithByteWidth<4> {
|
|
||||||
using type = uint32;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct IntegralTypeWithByteWidth<8> {
|
|
||||||
using type = uint64;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Converts part or all bits in an uint64 to the value of the floating point
|
// Converts part or all bits in an uint64 to the value of the floating point
|
||||||
// data type being tested.
|
// data type being tested.
|
||||||
//
|
//
|
||||||
@ -180,7 +184,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
|||||||
// T is the type of the floating value represented by the `bits`.
|
// T is the type of the floating value represented by the `bits`.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T ConvertValue(uint64 bits) {
|
T ConvertValue(uint64 bits) {
|
||||||
using I = typename IntegralTypeWithByteWidth<sizeof(T)>::type;
|
using I = typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
|
||||||
I used_bits = static_cast<I>(bits);
|
I used_bits = static_cast<I>(bits);
|
||||||
return BitCast<T>(used_bits);
|
return BitCast<T>(used_bits);
|
||||||
}
|
}
|
||||||
|
@ -364,7 +364,8 @@ class Exhaustive32BitOrLessUnaryTest
|
|||||||
// type being tested.
|
// type being tested.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void FillInput(Literal* input_literal) {
|
void FillInput(Literal* input_literal) {
|
||||||
using IntegralT = typename IntegralTypeWithByteWidth<sizeof(T)>::type;
|
using IntegralT =
|
||||||
|
typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
|
||||||
int64 input_size = input_literal->element_count();
|
int64 input_size = input_literal->element_count();
|
||||||
int64 begin, end;
|
int64 begin, end;
|
||||||
std::tie(begin, end) = std::get<1>(GetParam());
|
std::tie(begin, end) = std::get<1>(GetParam());
|
||||||
|
Loading…
Reference in New Issue
Block a user