[XLA] Refactor dot_operation_test's parameterization

PiperOrigin-RevId: 352483336
Change-Id: Ib5d3adf1e91a495fa7adb18d47ad99b94f99f9b4
This commit is contained in:
David Majnemer 2021-01-18 20:04:56 -08:00 committed by TensorFlower Gardener
parent 673079ce30
commit 4fcfcbe2c6
2 changed files with 27 additions and 19 deletions

View File

@ -43,26 +43,29 @@ class DotOperationTest : public ClientLibraryTestBase {
ErrorSpec error_spec_{0.0001, 1e-5}; ErrorSpec error_spec_{0.0001, 1e-5};
}; };
#if defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ using TypesF16F32 = ::testing::Types<
defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
using TypesF16F32 = ::testing::Types<float>; Eigen::half,
using TypesF16F32F64 = ::testing::Types<float>;
using TypesF16F32F64CF64 = ::testing::Types<float>;
#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
using TypesF16F32 = ::testing::Types<Eigen::half, float>;
using TypesF16F32F64 = ::testing::Types<Eigen::half, float, double>;
using TypesF16F32F64CF64 =
::testing::Types<Eigen::half, float, double, complex64>;
#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \
defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) && \
defined(XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX)
using TypesF16F32 = ::testing::Types<Eigen::half, float>;
using TypesF16F32F64 = ::testing::Types<Eigen::half, float>;
using TypesF16F32F64CF64 = ::testing::Types<Eigen::half, float>;
#else
#error "Situation not handled yet"
#endif #endif
float>;
using TypesF16F32F64 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
Eigen::half,
#endif
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
double,
#endif
float>;
using TypesF16F32F64CF64 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
Eigen::half,
#endif
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
double, complex64,
#endif
float>;
// Check that we can safely pass an input tuple's elements to a dot operation. // Check that we can safely pass an input tuple's elements to a dot operation.
XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {

View File

@ -477,6 +477,11 @@ void Exhaustive32BitOrLessUnaryTest<T>::SetParamsForSinCosTan() {
float f = static_cast<float>(BitCast<bfloat16>(static_cast<uint16>(v))); float f = static_cast<float>(BitCast<bfloat16>(static_cast<uint16>(v)));
return std::abs(f) > (1 << 13); return std::abs(f) > (1 << 13);
}; };
} else if (T == F16) {
this->known_incorrect_fn_ = [](int64 v) {
float f = static_cast<float>(BitCast<half>(static_cast<uint16>(v)));
return std::abs(f) > (1 << 13);
};
} }
} }