[XLA] Refactor dot_operation_test's parameterization

PiperOrigin-RevId: 352483336
Change-Id: Ib5d3adf1e91a495fa7adb18d47ad99b94f99f9b4
This commit is contained in:
David Majnemer 2021-01-18 20:04:56 -08:00 committed by TensorFlower Gardener
parent 673079ce30
commit 4fcfcbe2c6
2 changed files with 27 additions and 19 deletions

View File

@@ -43,26 +43,29 @@ class DotOperationTest : public ClientLibraryTestBase {
ErrorSpec error_spec_{0.0001, 1e-5};
};
#if defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \
defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
using TypesF16F32 = ::testing::Types<float>;
using TypesF16F32F64 = ::testing::Types<float>;
using TypesF16F32F64CF64 = ::testing::Types<float>;
#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
using TypesF16F32 = ::testing::Types<Eigen::half, float>;
using TypesF16F32F64 = ::testing::Types<Eigen::half, float, double>;
using TypesF16F32F64CF64 =
::testing::Types<Eigen::half, float, double, complex64>;
#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \
defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) && \
defined(XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX)
using TypesF16F32 = ::testing::Types<Eigen::half, float>;
using TypesF16F32F64 = ::testing::Types<Eigen::half, float>;
using TypesF16F32F64CF64 = ::testing::Types<Eigen::half, float>;
#else
#error "Situation not handled yet"
using TypesF16F32 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
Eigen::half,
#endif
float>;
using TypesF16F32F64 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
Eigen::half,
#endif
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
double,
#endif
float>;
using TypesF16F32F64CF64 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
Eigen::half,
#endif
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
double, complex64,
#endif
float>;
// Check that we can safely pass an input tuple's elements to a dot operation.
XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {

View File

@@ -477,6 +477,11 @@ void Exhaustive32BitOrLessUnaryTest<T>::SetParamsForSinCosTan() {
float f = static_cast<float>(BitCast<bfloat16>(static_cast<uint16>(v)));
return std::abs(f) > (1 << 13);
};
} else if (T == F16) {
this->known_incorrect_fn_ = [](int64 v) {
float f = static_cast<float>(BitCast<half>(static_cast<uint16>(v)));
return std::abs(f) > (1 << 13);
};
}
}