diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 570dbd744bc..2ea7ca57bd1 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -43,26 +43,29 @@ class DotOperationTest : public ClientLibraryTestBase { ErrorSpec error_spec_{0.0001, 1e-5}; }; -#if defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ - defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -using TypesF16F32 = ::testing::Types; -using TypesF16F32F64 = ::testing::Types; -using TypesF16F32F64CF64 = ::testing::Types; -#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ - !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -using TypesF16F32 = ::testing::Types; -using TypesF16F32F64 = ::testing::Types; -using TypesF16F32F64CF64 = - ::testing::Types; -#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ - defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) && \ - defined(XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX) -using TypesF16F32 = ::testing::Types; -using TypesF16F32F64 = ::testing::Types; -using TypesF16F32F64CF64 = ::testing::Types; -#else -#error "Situation not handled yet" +using TypesF16F32 = ::testing::Types< +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) + Eigen::half, #endif + float>; + +using TypesF16F32F64 = ::testing::Types< +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) + Eigen::half, +#endif +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) + double, +#endif + float>; + +using TypesF16F32F64CF64 = ::testing::Types< +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) + Eigen::half, +#endif +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) + double, complex64, +#endif + float>; // Check that we can safely pass an input tuple's elements to a dot operation. XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc index 8c0e984bf30..b88c7481765 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc @@ -477,6 +477,11 @@ void Exhaustive32BitOrLessUnaryTest::SetParamsForSinCosTan() { float f = static_cast(BitCast(static_cast(v))); return std::abs(f) > (1 << 13); }; + } else if (T == F16) { + this->known_incorrect_fn_ = [](int64 v) { + float f = static_cast(BitCast(static_cast(v))); + return std::abs(f) > (1 << 13); + }; } }