[XLA] Avoid defining a test that won't be instantiated for the binary being

built.

PiperOrigin-RevId: 267632386
This commit is contained in:
Bixia Zheng 2019-09-06 11:00:06 -07:00 committed by TensorFlower Gardener
parent d8551244b8
commit bac0574e43
4 changed files with 271 additions and 174 deletions

View File

@ -745,7 +745,9 @@ xla_test(
"no_oss", "no_oss",
], ],
deps = [ deps = [
":client_library_test_base",
":exhaustive_op_test_utils", ":exhaustive_op_test_utils",
"//tensorflow/compiler/xla:util",
], ],
) )
@ -765,7 +767,9 @@ xla_test(
"no_oss", "no_oss",
], ],
deps = [ deps = [
":client_library_test_base",
":exhaustive_op_test_utils", ":exhaustive_op_test_utils",
"//tensorflow/compiler/xla:util",
], ],
) )
@ -785,7 +789,9 @@ xla_test(
"no_oss", "no_oss",
], ],
deps = [ deps = [
":client_library_test_base",
":exhaustive_op_test_utils", ":exhaustive_op_test_utils",
"//tensorflow/compiler/xla:util",
], ],
) )
@ -1281,16 +1287,17 @@ xla_test(
srcs = ["slice_test.cc"], srcs = ["slice_test.cc"],
shard_count = 40, shard_count = 40,
deps = [ deps = [
":client_library_test_base",
":literal_test_util",
":test_macros_header", ":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core/platform:types",
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",

View File

@ -88,30 +88,41 @@ inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
}; };
} }
#define XLA_TEST_16BIT(test_name, ...) \ #if defined(BINARY_TEST_TARGET_F16) && defined(BINARY_TEST_TARGET_BF16)
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ #error "Can't define both BINARY_TEST_TARGET_F16 and BINARY_TEST_TARGET_BF16"
__VA_ARGS__ \ #endif
#if defined(BINARY_TEST_TARGET_F16) && \
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
#define BINARY_TEST_16BIT(test_name, ...) \
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \
__VA_ARGS__
#elif defined(BINARY_TEST_TARGET_BF16) && defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
#define BINARY_TEST_16BIT(test_name, ...) \
XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \ XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \
__VA_ARGS__ __VA_ARGS__
#else
#define BINARY_TEST_16BIT(test_name, ...)
#endif
XLA_TEST_16BIT(Add, { BINARY_TEST_16BIT(Add, {
auto host_add = [](float x, float y) { return x + y; }; auto host_add = [](float x, float y) { return x + y; };
Run(AddEmptyBroadcastDimension(Add), host_add); Run(AddEmptyBroadcastDimension(Add), host_add);
}) })
XLA_TEST_16BIT(Sub, { BINARY_TEST_16BIT(Sub, {
auto host_sub = [](float x, float y) { return x - y; }; auto host_sub = [](float x, float y) { return x - y; };
Run(AddEmptyBroadcastDimension(Sub), host_sub); Run(AddEmptyBroadcastDimension(Sub), host_sub);
}) })
// TODO(bixia): Mul fails with bfloat16 on CPU. // TODO(bixia): Mul fails with bfloat16 on CPU.
XLA_TEST_16BIT(DISABLED_ON_CPU(Mul), { BINARY_TEST_16BIT(DISABLED_ON_CPU(Mul), {
auto host_mul = [](float x, float y) { return x * y; }; auto host_mul = [](float x, float y) { return x * y; };
Run(AddEmptyBroadcastDimension(Mul), host_mul); Run(AddEmptyBroadcastDimension(Mul), host_mul);
}) })
// TODO(bixia): Div fails with bfloat16 on CPU. // TODO(bixia): Div fails with bfloat16 on CPU.
XLA_TEST_16BIT(DISABLED_ON_CPU(Div), { BINARY_TEST_16BIT(DISABLED_ON_CPU(Div), {
auto host_div = [](float x, float y) { return x / y; }; auto host_div = [](float x, float y) { return x / y; };
Run(AddEmptyBroadcastDimension(Div), host_div); Run(AddEmptyBroadcastDimension(Div), host_div);
}) })
@ -146,19 +157,21 @@ T ReferenceMin(T x, T y) {
return std::min<T>(x, y); return std::min<T>(x, y);
} }
XLA_TEST_16BIT(Max, BINARY_TEST_16BIT(Max, {
{ Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>); }) Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>);
})
XLA_TEST_16BIT(Min, BINARY_TEST_16BIT(Min, {
{ Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>); }) Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
})
// TODO(bixia): Pow fails with bfloat16 on CPU. // TODO(bixia): Pow fails with bfloat16 on CPU.
XLA_TEST_16BIT(DISABLED_ON_CPU(Pow), BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow),
{ Run(AddEmptyBroadcastDimension(Pow), std::powf); }) { Run(AddEmptyBroadcastDimension(Pow), std::powf); })
// TODO(bixia): Atan2 fails with bfloat16 on CPU. // TODO(bixia): Atan2 fails with bfloat16 on CPU.
XLA_TEST_16BIT(DISABLED_ON_CPU(Atan2), BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2),
{ Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); }) { Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); })
#if defined(BINARY_TEST_TARGET_F16) #if defined(BINARY_TEST_TARGET_F16)
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
@ -224,35 +237,43 @@ class Exhaustive32BitOrMoreBinaryTest
using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest<F32>; using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest<F32>;
using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>; using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;
XLA_TEST_P(ExhaustiveF32BinaryTest, Add) { #if defined(BINARY_TEST_TARGET_F32)
#define BINARY_TEST_FLOAT_32(test_name, ...) \
XLA_TEST_P(ExhaustiveF32BinaryTest, test_name) \
__VA_ARGS__
#else
#define BINARY_TEST_FLOAT_32(test_name, ...)
#endif
BINARY_TEST_FLOAT_32(Add, {
auto host_add = [](float x, float y) { return x + y; }; auto host_add = [](float x, float y) { return x + y; };
Run(AddEmptyBroadcastDimension(Add), host_add); Run(AddEmptyBroadcastDimension(Add), host_add);
} })
XLA_TEST_P(ExhaustiveF32BinaryTest, Sub) { BINARY_TEST_FLOAT_32(Sub, {
auto host_sub = [](float x, float y) { return x - y; }; auto host_sub = [](float x, float y) { return x - y; };
Run(AddEmptyBroadcastDimension(Sub), host_sub); Run(AddEmptyBroadcastDimension(Sub), host_sub);
} })
// TODO(bixia): Need to investigate the failure on CPU and file bugs. // TODO(bixia): Need to investigate the failure on CPU and file bugs.
XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Mul)) { BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(Mul), {
auto host_mul = [](float x, float y) { return x * y; }; auto host_mul = [](float x, float y) { return x * y; };
Run(AddEmptyBroadcastDimension(Mul), host_mul); Run(AddEmptyBroadcastDimension(Mul), host_mul);
} })
// TODO(bixia): Need to investigate the failure on CPU and file bugs. // TODO(bixia): Need to investigate the failure on CPU and file bugs.
XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Div)) { BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(Div), {
auto host_div = [](float x, float y) { return x / y; }; auto host_div = [](float x, float y) { return x / y; };
Run(AddEmptyBroadcastDimension(Div), host_div); Run(AddEmptyBroadcastDimension(Div), host_div);
} })
XLA_TEST_P(ExhaustiveF32BinaryTest, Max) { BINARY_TEST_FLOAT_32(Max, {
Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>); Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>);
} })
XLA_TEST_P(ExhaustiveF32BinaryTest, Min) { BINARY_TEST_FLOAT_32(Min, {
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>); Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
} })
// It is more convenient to implement Abs(complex) as a binary op than a unary // It is more convenient to implement Abs(complex) as a binary op than a unary
// op, as the operations we currently support all have the same data type for // op, as the operations we currently support all have the same data type for
@ -261,16 +282,14 @@ XLA_TEST_P(ExhaustiveF32BinaryTest, Min) {
// implement Abs(complex) as unary conveniently. // implement Abs(complex) as unary conveniently.
// //
// TODO(bixia): Need to investigate the failure on CPU and file bugs. // TODO(bixia): Need to investigate the failure on CPU and file bugs.
XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(AbsComplex)) { BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(AbsComplex), {
auto host_abs_complex = [](float x, float y) { auto host_abs_complex = [](float x, float y) {
return std::abs(std::complex<float>(x, y)); return std::abs(std::complex<float>(x, y));
}; };
auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }; auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };
Run(device_abs_complex, host_abs_complex); Run(device_abs_complex, host_abs_complex);
} })
#if defined(BINARY_TEST_TARGET_F32)
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
SpecialValues, ExhaustiveF32BinaryTest, SpecialValues, ExhaustiveF32BinaryTest,
@ -307,51 +326,55 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn( ::testing::ValuesIn(
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 2000)))); GetFpValuesForMagnitudeExtremeNormals<float>(40000, 2000))));
#if defined(BINARY_TEST_TARGET_F64) && \
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
#define BINARY_TEST_FLOAT_64(test_name, ...) \
XLA_TEST_P(ExhaustiveF64BinaryTest, test_name) \
__VA_ARGS__
#else
#define BINARY_TEST_FLOAT_64(test_name, ...)
#endif #endif
XLA_TEST_P(ExhaustiveF64BinaryTest, Add) { BINARY_TEST_FLOAT_64(Add, {
auto host_add = [](double x, double y) { return x + y; }; auto host_add = [](double x, double y) { return x + y; };
Run(AddEmptyBroadcastDimension(Add), host_add); Run(AddEmptyBroadcastDimension(Add), host_add);
} })
XLA_TEST_P(ExhaustiveF64BinaryTest, Sub) { BINARY_TEST_FLOAT_64(Sub, {
auto host_sub = [](double x, double y) { return x - y; }; auto host_sub = [](double x, double y) { return x - y; };
Run(AddEmptyBroadcastDimension(Sub), host_sub); Run(AddEmptyBroadcastDimension(Sub), host_sub);
} })
// TODO(bixia): Need to investigate the failure on CPU and file bugs. // TODO(bixia): Need to investigate the failure on CPU and file bugs.
XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Mul)) { BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(Mul), {
auto host_mul = [](double x, double y) { return x * y; }; auto host_mul = [](double x, double y) { return x * y; };
Run(AddEmptyBroadcastDimension(Mul), host_mul); Run(AddEmptyBroadcastDimension(Mul), host_mul);
} })
// TODO(bixia): Need to investigate the failure on CPU and file bugs. // TODO(bixia): Need to investigate the failure on CPU and file bugs.
XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Div)) { BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(Div), {
auto host_div = [](double x, double y) { return x / y; }; auto host_div = [](double x, double y) { return x / y; };
Run(AddEmptyBroadcastDimension(Div), host_div); Run(AddEmptyBroadcastDimension(Div), host_div);
} })
XLA_TEST_P(ExhaustiveF64BinaryTest, Max) { BINARY_TEST_FLOAT_64(Max, {
Run(AddEmptyBroadcastDimension(Max), ReferenceMax<double>); Run(AddEmptyBroadcastDimension(Max), ReferenceMax<double>);
} })
XLA_TEST_P(ExhaustiveF64BinaryTest, Min) { BINARY_TEST_FLOAT_64(Min, {
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<double>); Run(AddEmptyBroadcastDimension(Min), ReferenceMin<double>);
} })
// TODO(bixia): Need to investigate the failure on CPU and file bugs. // TODO(bixia): Need to investigate the failure on CPU and file bugs.
XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(AbsComplex)) { BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(AbsComplex), {
auto host_abs_complex = [](double x, double y) { auto host_abs_complex = [](double x, double y) {
return std::abs(std::complex<double>(x, y)); return std::abs(std::complex<double>(x, y));
}; };
auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }; auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };
Run(device_abs_complex, host_abs_complex); Run(device_abs_complex, host_abs_complex);
} })
#if defined(BINARY_TEST_TARGET_F64)
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
SpecialValues, ExhaustiveF64BinaryTest, SpecialValues, ExhaustiveF64BinaryTest,
::testing::Combine( ::testing::Combine(
@ -385,8 +408,6 @@ INSTANTIATE_TEST_SUITE_P(
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)), GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
::testing::ValuesIn( ::testing::ValuesIn(
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)))); GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
#endif
#endif
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h" #include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#ifdef __FAST_MATH__ #ifdef __FAST_MATH__
#error "Can't be compiled with fast math on" #error "Can't be compiled with fast math on"
@ -211,15 +213,54 @@ typedef Exhaustive32BitOrLessUnaryTest<F32> ExhaustiveF32UnaryTest;
typedef Exhaustive32BitOrLessUnaryTest<F16> ExhaustiveF16UnaryTest; typedef Exhaustive32BitOrLessUnaryTest<F16> ExhaustiveF16UnaryTest;
typedef Exhaustive32BitOrLessUnaryTest<BF16> ExhaustiveBF16UnaryTest; typedef Exhaustive32BitOrLessUnaryTest<BF16> ExhaustiveBF16UnaryTest;
#define XLA_TEST_FLOAT_32_BITS_OR_LESS(test_name, ...) \ #if defined(UNARY_TEST_TARGET_F32_OR_SMALLER)
XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \ #define NEED_UNARY_F32 true
__VA_ARGS__ \ #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \ #define NEED_UNARY_F16 true
__VA_ARGS__ \ #else
XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ #define NEED_UNARY_F16 false
__VA_ARGS__ #endif
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
#define NEED_UNARY_BF16 true
#else
#define NEED_UNARY_BF16 false
#endif
#else
#define NEED_UNARY_F32 false
#define NEED_UNARY_F16 false
#define NEED_UNARY_BF16 false
#endif
XLA_TEST_FLOAT_32_BITS_OR_LESS(Log, { #if NEED_UNARY_F32
#define UNARY_TEST_F32(test_name, ...) \
XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \
__VA_ARGS__
#else
#define UNARY_TEST_F32(test_name, ...)
#endif
#if NEED_UNARY_F16
#define UNARY_TEST_F16(test_name, ...) \
XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \
__VA_ARGS__
#else
#define UNARY_TEST_F16(test_name, ...)
#endif
#if NEED_UNARY_BF16
#define UNARY_TEST_BF16(test_name, ...) \
XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \
__VA_ARGS__
#else
#define UNARY_TEST_BF16(test_name, ...)
#endif
#define UNARY_TEST_FLOAT_32_BITS_OR_LESS(test_name, ...) \
UNARY_TEST_F32(test_name, __VA_ARGS__) \
UNARY_TEST_F16(test_name, __VA_ARGS__) \
UNARY_TEST_BF16(test_name, __VA_ARGS__)
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Log, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; }; error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; };
@ -227,7 +268,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Log, {
Run(Log, std::log, error_spec_gen); Run(Log, std::log, error_spec_gen);
}) })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Log1p, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Log1p, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; }; error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; };
@ -235,7 +276,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Log1p, {
Run(Log1p, std::log1p, error_spec_gen); Run(Log1p, std::log1p, error_spec_gen);
}) })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Exp, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Exp, {
// When x < -105, the true value of exp(x) is smaller than the smallest F32, // When x < -105, the true value of exp(x) is smaller than the smallest F32,
// so exp(x) should return exactly 0. We want our implementation of exp to // so exp(x) should return exactly 0. We want our implementation of exp to
// return exactly 0 as well, as not doing so implies either that our // return exactly 0 as well, as not doing so implies either that our
@ -266,7 +307,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Exp, {
} }
}) })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Expm1, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Expm1, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (ty_ == F32) { if (ty_ == F32) {
error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0.00015}; }; error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0.00015}; };
@ -292,7 +333,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Expm1, {
// It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but // It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but
// this *did* find a bug, namely that some backends were assuming sqrt(x) == // this *did* find a bug, namely that some backends were assuming sqrt(x) ==
// pow(x, 0.5), but this is not true for x == -inf. // pow(x, 0.5), but this is not true for x == -inf.
XLA_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, {
EvaluateOp fn = +[](float x) { return std::pow(x, 0.5f); }; EvaluateOp fn = +[](float x) { return std::pow(x, 0.5f); };
// TODO(b/123837116): Enable the test for all values after fixing the bug. // TODO(b/123837116): Enable the test for all values after fixing the bug.
if (platform_ != "Host" && platform_ != "CUDA") { if (platform_ != "Host" && platform_ != "CUDA") {
@ -306,12 +347,12 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, {
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn); Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn);
}) })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Rsqrt, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Rsqrt, {
Run( Run(
Rsqrt, +[](float x) { return 1 / std::sqrt(x); }); Rsqrt, +[](float x) { return 1 / std::sqrt(x); });
}) })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ == "Host" || platform_ == "CUDA") { if (platform_ == "Host" || platform_ == "CUDA") {
error_spec_gen = +[](NativeT x) { error_spec_gen = +[](NativeT x) {
@ -349,11 +390,11 @@ XLA_TEST_P(ExhaustiveF32UnaryTest, Asinh) {
XLA_TEST_P(ExhaustiveF16UnaryTest, Asinh) { Run(Asinh, std::asinh); } XLA_TEST_P(ExhaustiveF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
XLA_TEST_P(ExhaustiveBF16UnaryTest, Asinh) { Run(Asinh, std::asinh); } XLA_TEST_P(ExhaustiveBF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
XLA_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); }) UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); }) UNARY_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Asin, { Run(Asin, std::asin); }) UNARY_TEST_FLOAT_32_BITS_OR_LESS(Asin, { Run(Asin, std::asin); })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Cosh, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Cosh, {
// Our cosh implementation incorrectly overflows to inf for +/-89.4159851. // Our cosh implementation incorrectly overflows to inf for +/-89.4159851.
// The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to
// max-float, so we deem this acceptable. // max-float, so we deem this acceptable.
@ -374,7 +415,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Cosh, {
Run(Cosh, host_cosh); Run(Cosh, host_cosh);
}) })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Sinh, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sinh, {
// Our sinh implementation incorrectly overflows to +/-inf for +/-89.4159851. // Our sinh implementation incorrectly overflows to +/-inf for +/-89.4159851.
// The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to
// max-float, so we deem this acceptable. // max-float, so we deem this acceptable.
@ -395,7 +436,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Sinh, {
Run(Sinh, host_sinh); Run(Sinh, host_sinh);
}) })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Tanh, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Tanh, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ == "CUDA") { if (platform_ == "CUDA") {
error_spec_gen = +[](NativeT x) { error_spec_gen = +[](NativeT x) {
@ -429,62 +470,68 @@ void Exhaustive32BitOrLessUnaryTest<T>::SetParamsForSinCosTan() {
} }
} }
XLA_TEST_P(ExhaustiveF32UnaryTest, Cos) { UNARY_TEST_F32(Cos, {
SetParamsForSinCosTan(); SetParamsForSinCosTan();
Run( Run(
Cos, std::cos, +[](NativeT) { Cos, std::cos, +[](NativeT) {
return ErrorSpec{0.001, 0.001}; return ErrorSpec{0.001, 0.001};
}); });
} })
XLA_TEST_P(ExhaustiveF16UnaryTest, Cos) {
SetParamsForSinCosTan();
Run(Cos, std::cos);
}
XLA_TEST_P(ExhaustiveBF16UnaryTest, Cos) {
SetParamsForSinCosTan();
Run(Cos, std::cos);
}
XLA_TEST_P(ExhaustiveF32UnaryTest, Sin) { UNARY_TEST_F16(Cos, {
SetParamsForSinCosTan();
Run(Cos, std::cos);
})
UNARY_TEST_BF16(Cos, {
SetParamsForSinCosTan();
Run(Cos, std::cos);
})
UNARY_TEST_F32(Sin, {
SetParamsForSinCosTan(); SetParamsForSinCosTan();
Run( Run(
Sin, std::sin, +[](NativeT) { Sin, std::sin, +[](NativeT) {
return ErrorSpec{0.001, 0.001}; return ErrorSpec{0.001, 0.001};
}); });
} })
XLA_TEST_P(ExhaustiveF16UnaryTest, Sin) {
SetParamsForSinCosTan();
Run(Sin, std::sin);
}
XLA_TEST_P(ExhaustiveBF16UnaryTest, Sin) {
SetParamsForSinCosTan();
Run(Sin, std::sin);
}
XLA_TEST_P(ExhaustiveF32UnaryTest, Tan) { UNARY_TEST_F16(Sin, {
SetParamsForSinCosTan();
Run(Sin, std::sin);
})
UNARY_TEST_BF16(Sin, {
SetParamsForSinCosTan();
Run(Sin, std::sin);
})
UNARY_TEST_F32(Tan, {
SetParamsForSinCosTan(); SetParamsForSinCosTan();
Run( Run(
Tan, std::tan, +[](NativeT) { Tan, std::tan, +[](NativeT) {
return ErrorSpec{0.001, 0.001}; return ErrorSpec{0.001, 0.001};
}); });
} })
XLA_TEST_P(ExhaustiveF16UnaryTest, Tan) {
UNARY_TEST_F16(Tan, {
SetParamsForSinCosTan(); SetParamsForSinCosTan();
Run(Tan, std::tan); Run(Tan, std::tan);
} })
XLA_TEST_P(ExhaustiveBF16UnaryTest, Tan) {
UNARY_TEST_BF16(Tan, {
SetParamsForSinCosTan(); SetParamsForSinCosTan();
Run(Tan, std::tan); Run(Tan, std::tan);
} })
// TODO(jlebar): Enable these. // TODO(jlebar): Enable these.
// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan) { Run(Atan, std::atan); } // UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atan) { Run(Atan, std::atan); }
// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan2) { Run(Atan2, std::atan2); } // UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atan2) { Run(Atan2, std::atan2); }
XLA_TEST_FLOAT_32_BITS_OR_LESS(Erf, { Run(Erf, std::erf); }) UNARY_TEST_FLOAT_32_BITS_OR_LESS(Erf, { Run(Erf, std::erf); })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Erfc, { Run(Erfc, std::erfc); }) UNARY_TEST_FLOAT_32_BITS_OR_LESS(Erfc, { Run(Erfc, std::erfc); })
XLA_TEST_FLOAT_32_BITS_OR_LESS(ErfInv, { Run(ErfInv, HostErfInv); }) UNARY_TEST_FLOAT_32_BITS_OR_LESS(ErfInv, { Run(ErfInv, HostErfInv); })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Digamma, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Digamma, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ != "Host" && platform_ != "CUDA") { if (platform_ != "Host" && platform_ != "CUDA") {
// TODO(b/123956399): This is a fairly high error, significantly higher than // TODO(b/123956399): This is a fairly high error, significantly higher than
@ -514,7 +561,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Digamma, {
} }
}) })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, { UNARY_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, {
// Our implementation gets within 0.0001 rel error except for ~20 denormal // Our implementation gets within 0.0001 rel error except for ~20 denormal
// inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma. // inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma.
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
@ -545,9 +592,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, {
Run(Lgamma, host_lgamma, error_spec_gen); Run(Lgamma, host_lgamma, error_spec_gen);
}) })
XLA_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); }) UNARY_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); })
#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER)
INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest, INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest,
::testing::ValuesIn(CreateExhaustiveF32Ranges())); ::testing::ValuesIn(CreateExhaustiveF32Ranges()));
@ -562,8 +607,6 @@ INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest,
::testing::Values(std::make_pair(0, 1 << 16))); ::testing::Values(std::make_pair(0, 1 << 16)));
#endif #endif
#endif
// Exhaustive test for unary operations for double. // Exhaustive test for unary operations for double.
// //
// Test parameter is a tuple containing // Test parameter is a tuple containing
@ -594,42 +637,51 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest<F64>,
} }
}; };
XLA_TEST_P(ExhaustiveF64UnaryTest, Log) { Run(Log, std::log); } #if defined(UNARY_TEST_TARGET_F64) && \
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
#define UNARY_TEST_FLOAT_64(test_name, ...) \
XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \
__VA_ARGS__
#else
#define UNARY_TEST_FLOAT_64(test_name, ...)
#endif
XLA_TEST_P(ExhaustiveF64UnaryTest, Log1p) { Run(Log1p, std::log1p); } UNARY_TEST_FLOAT_64(Log, { Run(Log, std::log); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Exp) { Run(Exp, std::exp); } UNARY_TEST_FLOAT_64(Log1p, { Run(Log1p, std::log1p); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Expm1) { Run(Expm1, std::expm1); } UNARY_TEST_FLOAT_64(Exp, { Run(Exp, std::exp); })
UNARY_TEST_FLOAT_64(Expm1, { Run(Expm1, std::expm1); })
// TODO(b/138385863): Turn on the test for GPU after fixing the bug. // TODO(b/138385863): Turn on the test for GPU after fixing the bug.
XLA_TEST_P(ExhaustiveF64UnaryTest, DISABLED_ON_GPU(PowOneHalf)) { UNARY_TEST_FLOAT_64(DISABLED_ON_GPU(PowOneHalf), {
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
+[](double x) { return std::pow(x, 0.5); }); +[](double x) { return std::pow(x, 0.5); });
} })
XLA_TEST_P(ExhaustiveF64UnaryTest, Rsqrt) { UNARY_TEST_FLOAT_64(Rsqrt, {
Run( Run(
Rsqrt, +[](double x) { return 1 / std::sqrt(x); }); Rsqrt, +[](double x) { return 1 / std::sqrt(x); });
} })
XLA_TEST_P(ExhaustiveF64UnaryTest, Sqrt) { Run(Sqrt, std::sqrt); } UNARY_TEST_FLOAT_64(Sqrt, { Run(Sqrt, std::sqrt); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Acosh) { Run(Acosh, std::acosh); } UNARY_TEST_FLOAT_64(Acosh, { Run(Acosh, std::acosh); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Asinh) { Run(Asinh, std::asinh); } UNARY_TEST_FLOAT_64(Asinh, { Run(Asinh, std::asinh); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Atanh) { Run(Atanh, std::atanh); } UNARY_TEST_FLOAT_64(Atanh, { Run(Atanh, std::atanh); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Acos) { Run(Acos, std::acos); } UNARY_TEST_FLOAT_64(Acos, { Run(Acos, std::acos); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Asin) { Run(Asin, std::asin); } UNARY_TEST_FLOAT_64(Asin, { Run(Asin, std::asin); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Cosh) { Run(Cosh, std::cosh); } UNARY_TEST_FLOAT_64(Cosh, { Run(Cosh, std::cosh); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Sinh) { Run(Sinh, std::sinh); } UNARY_TEST_FLOAT_64(Sinh, { Run(Sinh, std::sinh); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Tanh) { UNARY_TEST_FLOAT_64(Tanh, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ == "CUDA") { if (platform_ == "CUDA") {
error_spec_gen = +[](NativeT x) { error_spec_gen = +[](NativeT x) {
@ -639,26 +691,24 @@ XLA_TEST_P(ExhaustiveF64UnaryTest, Tanh) {
}; };
} }
Run(Tanh, std::tanh, error_spec_gen); Run(Tanh, std::tanh, error_spec_gen);
} })
XLA_TEST_P(ExhaustiveF64UnaryTest, Cos) { Run(Cos, std::cos); } UNARY_TEST_FLOAT_64(Cos, { Run(Cos, std::cos); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Sin) { Run(Sin, std::sin); } UNARY_TEST_FLOAT_64(Sin, { Run(Sin, std::sin); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Tan) { Run(Tan, std::tan); } UNARY_TEST_FLOAT_64(Tan, { Run(Tan, std::tan); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Round) { Run(Round, std::round); } UNARY_TEST_FLOAT_64(Round, { Run(Round, std::round); })
XLA_TEST_P(ExhaustiveF64UnaryTest, Erf) { UNARY_TEST_FLOAT_64(Erf, {
Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; }); Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
} })
XLA_TEST_P(ExhaustiveF64UnaryTest, Erfc) { UNARY_TEST_FLOAT_64(Erfc, {
Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; }); Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
} })
#if defined(UNARY_TEST_TARGET_F64)
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
SpecialValues, ExhaustiveF64UnaryTest, SpecialValues, ExhaustiveF64UnaryTest,
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())); ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()));
@ -672,8 +722,6 @@ INSTANTIATE_TEST_SUITE_P(
LargeAndSmallMagnituedNormalValues, ExhaustiveF64UnaryTest, LargeAndSmallMagnituedNormalValues, ExhaustiveF64UnaryTest,
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>( ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
4000000000ull, 16000000))); 4000000000ull, 16000000)));
#endif
#endif
// T is the Primitive Type of the complex number // T is the Primitive Type of the complex number
// Test parameter is a tuple containing // Test parameter is a tuple containing
@ -741,30 +789,38 @@ class ExhaustiveComplexUnaryTestBase
typedef ExhaustiveComplexUnaryTestBase<C64> ExhaustiveC64UnaryTest; typedef ExhaustiveComplexUnaryTestBase<C64> ExhaustiveC64UnaryTest;
typedef ExhaustiveComplexUnaryTestBase<C128> ExhaustiveC128UnaryTest; typedef ExhaustiveComplexUnaryTestBase<C128> ExhaustiveC128UnaryTest;
// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug. #if defined(UNARY_TEST_TARGET_COMPLEX)
XLA_TEST_P(ExhaustiveC64UnaryTest, DISABLED_ON_CPU(Log)) { #define UNARY_TEST_COMPLEX_64(test_name, ...) \
Run(Log, [](complex64 x) { return std::log<float>(x); }); XLA_TEST_P(ExhaustiveC64UnaryTest, test_name) \
} __VA_ARGS__
#else
#define UNARY_TEST_COMPLEX_64(test_name, ...)
#endif
XLA_TEST_P(ExhaustiveC64UnaryTest, Sqrt) { // TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
UNARY_TEST_COMPLEX_64(DISABLED_ON_CPU(Log), {
Run(Log, [](complex64 x) { return std::log<float>(x); });
})
UNARY_TEST_COMPLEX_64(Sqrt, {
Run(Sqrt, [](complex64 x) { Run(Sqrt, [](complex64 x) {
return static_cast<complex64>( return static_cast<complex64>(
std::sqrt<double>(static_cast<complex128>(x))); std::sqrt<double>(static_cast<complex128>(x)));
}); });
} })
XLA_TEST_P(ExhaustiveC64UnaryTest, Rsqrt) { UNARY_TEST_COMPLEX_64(Rsqrt, {
Run(Rsqrt, [](complex64 x) { Run(Rsqrt, [](complex64 x) {
return static_cast<complex64>( return static_cast<complex64>(
complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x))); complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x)));
}); });
} })
// The current libc++ implementation of the complex tanh function provides // The current libc++ implementation of the complex tanh function provides
// less accurate results when the denomenator of a complex tanh is small, due // less accurate results when the denomenator of a complex tanh is small, due
// to floating point precision loss. To avoid this issue for complex64 numbers, // to floating point precision loss. To avoid this issue for complex64 numbers,
// we cast it to and from a complex128 when computing tanh. // we cast it to and from a complex128 when computing tanh.
XLA_TEST_P(ExhaustiveC64UnaryTest, Tanh) { UNARY_TEST_COMPLEX_64(Tanh, {
SetParamsForTanh(); SetParamsForTanh();
ErrorSpecGen error_spec_gen = +[](complex64 x) { ErrorSpecGen error_spec_gen = +[](complex64 x) {
// This implementation of Tanh becomes less accurate when the denominator // This implementation of Tanh becomes less accurate when the denominator
@ -781,9 +837,8 @@ XLA_TEST_P(ExhaustiveC64UnaryTest, Tanh) {
return static_cast<complex64>(std::tanh(static_cast<complex128>(x))); return static_cast<complex64>(std::tanh(static_cast<complex128>(x)));
}, },
error_spec_gen); error_spec_gen);
} })
#if defined(UNARY_TEST_TARGET_COMPLEX)
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
F32SpecialValues, ExhaustiveC64UnaryTest, F32SpecialValues, ExhaustiveC64UnaryTest,
::testing::Combine( ::testing::Combine(
@ -816,10 +871,17 @@ INSTANTIATE_TEST_SUITE_P(
4000)), 4000)),
::testing::ValuesIn( ::testing::ValuesIn(
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000)))); GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000))));
#if defined(UNARY_TEST_TARGET_COMPLEX) && \
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
#define UNARY_TEST_COMPLEX_128(test_name, ...) \
XLA_TEST_P(ExhaustiveC128UnaryTest, test_name) \
__VA_ARGS__
#else
#define UNARY_TEST_COMPLEX_128(test_name, ...)
#endif #endif
UNARY_TEST_COMPLEX_128(Log, {
XLA_TEST_P(ExhaustiveC128UnaryTest, Log) {
// TODO(b/138578313): Enable the test for all values after fixing the bug. // TODO(b/138578313): Enable the test for all values after fixing the bug.
known_incorrect_fn_ = [&](int64 v) { known_incorrect_fn_ = [&](int64 v) {
double f = this->ConvertValue(v); double f = this->ConvertValue(v);
@ -827,18 +889,18 @@ XLA_TEST_P(ExhaustiveC128UnaryTest, Log) {
std::abs(f) < 1.0e-300; std::abs(f) < 1.0e-300;
}; };
Run(Log, [](complex128 x) { return std::log<double>(x); }); Run(Log, [](complex128 x) { return std::log<double>(x); });
} })
XLA_TEST_P(ExhaustiveC128UnaryTest, Sqrt) { UNARY_TEST_COMPLEX_128(Sqrt, {
// Similar to the Tanh bug. // Similar to the Tanh bug.
known_incorrect_fn_ = [&](int64 v) { known_incorrect_fn_ = [&](int64 v) {
double f = this->ConvertValue(v); double f = this->ConvertValue(v);
return std::abs(f) > std::numeric_limits<double>::max() / 2; return std::abs(f) > std::numeric_limits<double>::max() / 2;
}; };
Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); }); Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); });
} })
XLA_TEST_P(ExhaustiveC128UnaryTest, Rsqrt) { UNARY_TEST_COMPLEX_128(Rsqrt, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ == "CUDA") { if (platform_ == "CUDA") {
// Edge case on CUDA backend where the Log of a complex number made up of // Edge case on CUDA backend where the Log of a complex number made up of
@ -856,16 +918,14 @@ XLA_TEST_P(ExhaustiveC128UnaryTest, Rsqrt) {
Rsqrt, Rsqrt,
[](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); }, [](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); },
error_spec_gen); error_spec_gen);
} })
XLA_TEST_P(ExhaustiveC128UnaryTest, Tanh) { UNARY_TEST_COMPLEX_128(Tanh, {
SetParamsForTanh(); SetParamsForTanh();
Run( Run(
Tanh, +[](complex128 x) { return std::tanh(x); }); Tanh, +[](complex128 x) { return std::tanh(x); });
} })
#if defined(UNARY_TEST_TARGET_COMPLEX)
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
SpecialValues, ExhaustiveC128UnaryTest, SpecialValues, ExhaustiveC128UnaryTest,
::testing::Combine( ::testing::Combine(
@ -898,7 +958,5 @@ INSTANTIATE_TEST_SUITE_P(
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)), GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
::testing::ValuesIn( ::testing::ValuesIn(
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)))); GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
#endif
#endif
} // namespace xla } // namespace xla

View File

@ -259,17 +259,31 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); }
XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }
XLA_TEST_P(SliceR1LargeTest, DoIt_F32) { Run<float>(GetParam()); } // TODO(b/69425338): The following tests are disable on GPU because they use
// too much GPU memory.
XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_F32)) {
Run<float>(GetParam());
}
XLA_TEST_P(SliceR1LargeTest, DoIt_F64) { Run<double>(GetParam()); } XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_F64)) {
Run<double>(GetParam());
}
XLA_TEST_P(SliceR1LargeTest, DoIt_U32) { Run<uint32>(GetParam()); } XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_U32)) {
Run<uint32>(GetParam());
}
XLA_TEST_P(SliceR1LargeTest, DoIt_S32) { Run<int32>(GetParam()); } XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_S32)) {
Run<int32>(GetParam());
}
XLA_TEST_P(SliceR1LargeTest, DoIt_U64) { Run<uint64>(GetParam()); } XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_U64)) {
Run<uint64>(GetParam());
}
XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run<int64>(GetParam()); } XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_S64)) {
Run<int64>(GetParam());
}
XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run<bool>(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run<bool>(GetParam()); }
@ -315,8 +329,6 @@ INSTANTIATE_TEST_CASE_P(
SliceR1TestDataToString SliceR1TestDataToString
); );
// TODO(b/69425338): This uses too much memory on GPU.
#ifndef XLA_TEST_BACKEND_GPU
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
SliceR1TestBigSlicesInstantiation, SliceR1TestBigSlicesInstantiation,
SliceR1LargeTest, SliceR1LargeTest,
@ -330,7 +342,6 @@ INSTANTIATE_TEST_CASE_P(
), ),
SliceR1TestDataToString SliceR1TestDataToString
); );
#endif
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
SliceStridedR1TestInstantiation, SliceStridedR1TestInstantiation,