diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc index 465da47faeb..02273d7debd 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc @@ -58,19 +58,19 @@ ExhaustiveOpTestBase::CreateExhaustiveF32Ranges() { namespace { ExhaustiveOpTestBase::ErrorSpec DefaultF64SpecGenerator(float) { - return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; + return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001); } ExhaustiveOpTestBase::ErrorSpec DefaultF32SpecGenerator(float) { - return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; + return ExhaustiveOpTestBase::ErrorSpec(0.0001, 0.0001); } ExhaustiveOpTestBase::ErrorSpec DefaultF16SpecGenerator(float) { - return ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001}; + return ExhaustiveOpTestBase::ErrorSpec(0.001, 0.001); } ExhaustiveOpTestBase::ErrorSpec DefaultBF16SpecGenerator(float) { - return ExhaustiveOpTestBase::ErrorSpec{0.002, 0.02}; + return ExhaustiveOpTestBase::ErrorSpec(0.002, 0.02); } } // namespace diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h index ad42779ddc7..be16fddc756 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h @@ -30,6 +30,26 @@ limitations under the License. namespace xla { using Eigen::half; +namespace test_util { +template +struct IntegralTypeWithByteWidth {}; + +template <> +struct IntegralTypeWithByteWidth<2> { + using type = uint16; +}; + +template <> +struct IntegralTypeWithByteWidth<4> { + using type = uint32; +}; + +template <> +struct IntegralTypeWithByteWidth<8> { + using type = uint64; +}; +} // namespace test_util + class ExhaustiveOpTestBase : public ClientLibraryTestBase { public: struct ErrorSpec { @@ -41,6 +61,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // spec; this only covers the case when both `expected` and `actual` are // equal to 0. bool strict_signed_zeros = false; + + ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {} }; // `ty` is the primitive type being tested. @@ -150,24 +172,6 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { } } - template - struct IntegralTypeWithByteWidth {}; - - template <> - struct IntegralTypeWithByteWidth<2> { - using type = uint16; - }; - - template <> - struct IntegralTypeWithByteWidth<4> { - using type = uint32; - }; - - template <> - struct IntegralTypeWithByteWidth<8> { - using type = uint64; - }; - // Converts part or all bits in an uint64 to the value of the floating point // data type being tested. // @@ -180,7 +184,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // T is the type of the floating value represented by the `bits`. template T ConvertValue(uint64 bits) { - using I = typename IntegralTypeWithByteWidth::type; + using I = typename test_util::IntegralTypeWithByteWidth::type; I used_bits = static_cast(bits); return BitCast(used_bits); } diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc index 5f82af95245..4019a5f78f8 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc @@ -364,7 +364,8 @@ class Exhaustive32BitOrLessUnaryTest // type being tested. template void FillInput(Literal* input_literal) { - using IntegralT = typename IntegralTypeWithByteWidth::type; + using IntegralT = + typename test_util::IntegralTypeWithByteWidth::type; int64 input_size = input_literal->element_count(); int64 begin, end; std::tie(begin, end) = std::get<1>(GetParam());