Merge pull request #30434 from DavidNorman:resolve-exhaustive-test-compile-issue

PiperOrigin-RevId: 259404629
This commit is contained in:
TensorFlower Gardener 2019-07-22 14:37:34 -07:00
commit 250436404a
3 changed files with 29 additions and 24 deletions

View File

@ -58,19 +58,19 @@ ExhaustiveOpTestBase::CreateExhaustiveF32Ranges() {
namespace { namespace {
ExhaustiveOpTestBase::ErrorSpec DefaultF64SpecGenerator(float) { ExhaustiveOpTestBase::ErrorSpec DefaultF64SpecGenerator(float) {
return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
} }
ExhaustiveOpTestBase::ErrorSpec DefaultF32SpecGenerator(float) { ExhaustiveOpTestBase::ErrorSpec DefaultF32SpecGenerator(float) {
return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
} }
ExhaustiveOpTestBase::ErrorSpec DefaultF16SpecGenerator(float) { ExhaustiveOpTestBase::ErrorSpec DefaultF16SpecGenerator(float) {
return ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001}; return ExhaustiveOpTestBase::ErrorSpec(0.001, 0.001);
} }
ExhaustiveOpTestBase::ErrorSpec DefaultBF16SpecGenerator(float) { ExhaustiveOpTestBase::ErrorSpec DefaultBF16SpecGenerator(float) {
return ExhaustiveOpTestBase::ErrorSpec{0.002, 0.02}; return ExhaustiveOpTestBase::ErrorSpec(0.002, 0.02);
} }
} // namespace } // namespace

View File

@ -30,6 +30,26 @@ limitations under the License.
namespace xla { namespace xla {
using Eigen::half; using Eigen::half;
namespace test_util {
template <int N>
struct IntegralTypeWithByteWidth {};
template <>
struct IntegralTypeWithByteWidth<2> {
using type = uint16;
};
template <>
struct IntegralTypeWithByteWidth<4> {
using type = uint32;
};
template <>
struct IntegralTypeWithByteWidth<8> {
using type = uint64;
};
} // namespace test_util
class ExhaustiveOpTestBase : public ClientLibraryTestBase { class ExhaustiveOpTestBase : public ClientLibraryTestBase {
public: public:
struct ErrorSpec { struct ErrorSpec {
@ -41,6 +61,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
// spec; this only covers the case when both `expected` and `actual` are // spec; this only covers the case when both `expected` and `actual` are
// equal to 0. // equal to 0.
bool strict_signed_zeros = false; bool strict_signed_zeros = false;
ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
}; };
// `ty` is the primitive type being tested. // `ty` is the primitive type being tested.
@ -150,24 +172,6 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
} }
} }
template <int N>
struct IntegralTypeWithByteWidth {};
template <>
struct IntegralTypeWithByteWidth<2> {
using type = uint16;
};
template <>
struct IntegralTypeWithByteWidth<4> {
using type = uint32;
};
template <>
struct IntegralTypeWithByteWidth<8> {
using type = uint64;
};
// Converts part or all bits in an uint64 to the value of the floating point // Converts part or all bits in an uint64 to the value of the floating point
// data type being tested. // data type being tested.
// //
@ -180,7 +184,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
// T is the type of the floating value represented by the `bits`. // T is the type of the floating value represented by the `bits`.
template <typename T> template <typename T>
T ConvertValue(uint64 bits) { T ConvertValue(uint64 bits) {
using I = typename IntegralTypeWithByteWidth<sizeof(T)>::type; using I = typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
I used_bits = static_cast<I>(bits); I used_bits = static_cast<I>(bits);
return BitCast<T>(used_bits); return BitCast<T>(used_bits);
} }

View File

@ -364,7 +364,8 @@ class Exhaustive32BitOrLessUnaryTest
// type being tested. // type being tested.
template <typename T> template <typename T>
void FillInput(Literal* input_literal) { void FillInput(Literal* input_literal) {
using IntegralT = typename IntegralTypeWithByteWidth<sizeof(T)>::type; using IntegralT =
typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
int64 input_size = input_literal->element_count(); int64 input_size = input_literal->element_count();
int64 begin, end; int64 begin, end;
std::tie(begin, end) = std::get<1>(GetParam()); std::tie(begin, end) = std::get<1>(GetParam());