Merge pull request #30434 from DavidNorman:resolve-exhaustive-test-compile-issue

PiperOrigin-RevId: 259404629
This commit is contained in:
TensorFlower Gardener 2019-07-22 14:37:34 -07:00
commit 250436404a
3 changed files with 29 additions and 24 deletions

View File

@ -58,19 +58,19 @@ ExhaustiveOpTestBase::CreateExhaustiveF32Ranges() {
namespace {
ExhaustiveOpTestBase::ErrorSpec DefaultF64SpecGenerator(float) {
return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001};
return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
}
ExhaustiveOpTestBase::ErrorSpec DefaultF32SpecGenerator(float) {
return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001};
return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001);
}
ExhaustiveOpTestBase::ErrorSpec DefaultF16SpecGenerator(float) {
return ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001};
return ExhaustiveOpTestBase::ErrorSpec(0.001, 0.001);
}
ExhaustiveOpTestBase::ErrorSpec DefaultBF16SpecGenerator(float) {
return ExhaustiveOpTestBase::ErrorSpec{0.002, 0.02};
return ExhaustiveOpTestBase::ErrorSpec(0.002, 0.02);
}
} // namespace

View File

@ -30,6 +30,26 @@ limitations under the License.
namespace xla {
using Eigen::half;
namespace test_util {
template <int N>
struct IntegralTypeWithByteWidth {};
template <>
struct IntegralTypeWithByteWidth<2> {
using type = uint16;
};
template <>
struct IntegralTypeWithByteWidth<4> {
using type = uint32;
};
template <>
struct IntegralTypeWithByteWidth<8> {
using type = uint64;
};
} // namespace test_util
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
public:
struct ErrorSpec {
@ -41,6 +61,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
// spec; this only covers the case when both `expected` and `actual` are
// equal to 0.
bool strict_signed_zeros = false;
ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
};
// `ty` is the primitive type being tested.
@ -150,24 +172,6 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
}
}
template <int N>
struct IntegralTypeWithByteWidth {};
template <>
struct IntegralTypeWithByteWidth<2> {
using type = uint16;
};
template <>
struct IntegralTypeWithByteWidth<4> {
using type = uint32;
};
template <>
struct IntegralTypeWithByteWidth<8> {
using type = uint64;
};
// Converts part or all bits in an uint64 to the value of the floating point
// data type being tested.
//
@ -180,7 +184,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
// T is the type of the floating value represented by the `bits`.
template <typename T>
T ConvertValue(uint64 bits) {
using I = typename IntegralTypeWithByteWidth<sizeof(T)>::type;
using I = typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
I used_bits = static_cast<I>(bits);
return BitCast<T>(used_bits);
}

View File

@ -364,7 +364,8 @@ class Exhaustive32BitOrLessUnaryTest
// type being tested.
template <typename T>
void FillInput(Literal* input_literal) {
using IntegralT = typename IntegralTypeWithByteWidth<sizeof(T)>::type;
using IntegralT =
typename test_util::IntegralTypeWithByteWidth<sizeof(T)>::type;
int64 input_size = input_literal->element_count();
int64 begin, end;
std::tie(begin, end) = std::get<1>(GetParam());