[XLA] Avoid defining a test that won't be instantiated for the binary being
built. PiperOrigin-RevId: 267632386
This commit is contained in:
parent
d8551244b8
commit
bac0574e43
@ -745,7 +745,9 @@ xla_test(
|
|||||||
"no_oss",
|
"no_oss",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":client_library_test_base",
|
||||||
":exhaustive_op_test_utils",
|
":exhaustive_op_test_utils",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -765,7 +767,9 @@ xla_test(
|
|||||||
"no_oss",
|
"no_oss",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":client_library_test_base",
|
||||||
":exhaustive_op_test_utils",
|
":exhaustive_op_test_utils",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -785,7 +789,9 @@ xla_test(
|
|||||||
"no_oss",
|
"no_oss",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":client_library_test_base",
|
||||||
":exhaustive_op_test_utils",
|
":exhaustive_op_test_utils",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1281,16 +1287,17 @@ xla_test(
|
|||||||
srcs = ["slice_test.cc"],
|
srcs = ["slice_test.cc"],
|
||||||
shard_count = 40,
|
shard_count = 40,
|
||||||
deps = [
|
deps = [
|
||||||
|
":client_library_test_base",
|
||||||
|
":literal_test_util",
|
||||||
":test_macros_header",
|
":test_macros_header",
|
||||||
|
":xla_internal_test_main",
|
||||||
"//tensorflow/compiler/xla:array2d",
|
"//tensorflow/compiler/xla:array2d",
|
||||||
"//tensorflow/compiler/xla:reference_util",
|
"//tensorflow/compiler/xla:reference_util",
|
||||||
"//tensorflow/compiler/xla/client:local_client",
|
"//tensorflow/compiler/xla/client:local_client",
|
||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
|
||||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core/platform:types",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
@ -88,30 +88,41 @@ inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
#define XLA_TEST_16BIT(test_name, ...) \
|
#if defined(BINARY_TEST_TARGET_F16) && defined(BINARY_TEST_TARGET_BF16)
|
||||||
|
#error "Can't define both BINARY_TEST_TARGET_F16 and BINARY_TEST_TARGET_BF16"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(BINARY_TEST_TARGET_F16) && \
|
||||||
|
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||||
|
#define BINARY_TEST_16BIT(test_name, ...) \
|
||||||
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \
|
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__
|
||||||
|
#elif defined(BINARY_TEST_TARGET_BF16) && defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
|
||||||
|
#define BINARY_TEST_16BIT(test_name, ...) \
|
||||||
XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \
|
XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \
|
||||||
__VA_ARGS__
|
__VA_ARGS__
|
||||||
|
#else
|
||||||
|
#define BINARY_TEST_16BIT(test_name, ...)
|
||||||
|
#endif
|
||||||
|
|
||||||
XLA_TEST_16BIT(Add, {
|
BINARY_TEST_16BIT(Add, {
|
||||||
auto host_add = [](float x, float y) { return x + y; };
|
auto host_add = [](float x, float y) { return x + y; };
|
||||||
Run(AddEmptyBroadcastDimension(Add), host_add);
|
Run(AddEmptyBroadcastDimension(Add), host_add);
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_16BIT(Sub, {
|
BINARY_TEST_16BIT(Sub, {
|
||||||
auto host_sub = [](float x, float y) { return x - y; };
|
auto host_sub = [](float x, float y) { return x - y; };
|
||||||
Run(AddEmptyBroadcastDimension(Sub), host_sub);
|
Run(AddEmptyBroadcastDimension(Sub), host_sub);
|
||||||
})
|
})
|
||||||
|
|
||||||
// TODO(bixia): Mul fails with bfloat16 on CPU.
|
// TODO(bixia): Mul fails with bfloat16 on CPU.
|
||||||
XLA_TEST_16BIT(DISABLED_ON_CPU(Mul), {
|
BINARY_TEST_16BIT(DISABLED_ON_CPU(Mul), {
|
||||||
auto host_mul = [](float x, float y) { return x * y; };
|
auto host_mul = [](float x, float y) { return x * y; };
|
||||||
Run(AddEmptyBroadcastDimension(Mul), host_mul);
|
Run(AddEmptyBroadcastDimension(Mul), host_mul);
|
||||||
})
|
})
|
||||||
|
|
||||||
// TODO(bixia): Div fails with bfloat16 on CPU.
|
// TODO(bixia): Div fails with bfloat16 on CPU.
|
||||||
XLA_TEST_16BIT(DISABLED_ON_CPU(Div), {
|
BINARY_TEST_16BIT(DISABLED_ON_CPU(Div), {
|
||||||
auto host_div = [](float x, float y) { return x / y; };
|
auto host_div = [](float x, float y) { return x / y; };
|
||||||
Run(AddEmptyBroadcastDimension(Div), host_div);
|
Run(AddEmptyBroadcastDimension(Div), host_div);
|
||||||
})
|
})
|
||||||
@ -146,18 +157,20 @@ T ReferenceMin(T x, T y) {
|
|||||||
return std::min<T>(x, y);
|
return std::min<T>(x, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_16BIT(Max,
|
BINARY_TEST_16BIT(Max, {
|
||||||
{ Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>); })
|
Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>);
|
||||||
|
})
|
||||||
|
|
||||||
XLA_TEST_16BIT(Min,
|
BINARY_TEST_16BIT(Min, {
|
||||||
{ Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>); })
|
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
|
||||||
|
})
|
||||||
|
|
||||||
// TODO(bixia): Pow fails with bfloat16 on CPU.
|
// TODO(bixia): Pow fails with bfloat16 on CPU.
|
||||||
XLA_TEST_16BIT(DISABLED_ON_CPU(Pow),
|
BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow),
|
||||||
{ Run(AddEmptyBroadcastDimension(Pow), std::powf); })
|
{ Run(AddEmptyBroadcastDimension(Pow), std::powf); })
|
||||||
|
|
||||||
// TODO(bixia): Atan2 fails with bfloat16 on CPU.
|
// TODO(bixia): Atan2 fails with bfloat16 on CPU.
|
||||||
XLA_TEST_16BIT(DISABLED_ON_CPU(Atan2),
|
BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2),
|
||||||
{ Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); })
|
{ Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); })
|
||||||
|
|
||||||
#if defined(BINARY_TEST_TARGET_F16)
|
#if defined(BINARY_TEST_TARGET_F16)
|
||||||
@ -224,35 +237,43 @@ class Exhaustive32BitOrMoreBinaryTest
|
|||||||
using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest<F32>;
|
using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest<F32>;
|
||||||
using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;
|
using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF32BinaryTest, Add) {
|
#if defined(BINARY_TEST_TARGET_F32)
|
||||||
|
#define BINARY_TEST_FLOAT_32(test_name, ...) \
|
||||||
|
XLA_TEST_P(ExhaustiveF32BinaryTest, test_name) \
|
||||||
|
__VA_ARGS__
|
||||||
|
#else
|
||||||
|
#define BINARY_TEST_FLOAT_32(test_name, ...)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
BINARY_TEST_FLOAT_32(Add, {
|
||||||
auto host_add = [](float x, float y) { return x + y; };
|
auto host_add = [](float x, float y) { return x + y; };
|
||||||
Run(AddEmptyBroadcastDimension(Add), host_add);
|
Run(AddEmptyBroadcastDimension(Add), host_add);
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF32BinaryTest, Sub) {
|
BINARY_TEST_FLOAT_32(Sub, {
|
||||||
auto host_sub = [](float x, float y) { return x - y; };
|
auto host_sub = [](float x, float y) { return x - y; };
|
||||||
Run(AddEmptyBroadcastDimension(Sub), host_sub);
|
Run(AddEmptyBroadcastDimension(Sub), host_sub);
|
||||||
}
|
})
|
||||||
|
|
||||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||||
XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Mul)) {
|
BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(Mul), {
|
||||||
auto host_mul = [](float x, float y) { return x * y; };
|
auto host_mul = [](float x, float y) { return x * y; };
|
||||||
Run(AddEmptyBroadcastDimension(Mul), host_mul);
|
Run(AddEmptyBroadcastDimension(Mul), host_mul);
|
||||||
}
|
})
|
||||||
|
|
||||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||||
XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Div)) {
|
BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(Div), {
|
||||||
auto host_div = [](float x, float y) { return x / y; };
|
auto host_div = [](float x, float y) { return x / y; };
|
||||||
Run(AddEmptyBroadcastDimension(Div), host_div);
|
Run(AddEmptyBroadcastDimension(Div), host_div);
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF32BinaryTest, Max) {
|
BINARY_TEST_FLOAT_32(Max, {
|
||||||
Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>);
|
Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>);
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF32BinaryTest, Min) {
|
BINARY_TEST_FLOAT_32(Min, {
|
||||||
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
|
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
|
||||||
}
|
})
|
||||||
|
|
||||||
// It is more convenient to implement Abs(complex) as a binary op than a unary
|
// It is more convenient to implement Abs(complex) as a binary op than a unary
|
||||||
// op, as the operations we currently support all have the same data type for
|
// op, as the operations we currently support all have the same data type for
|
||||||
@ -261,16 +282,14 @@ XLA_TEST_P(ExhaustiveF32BinaryTest, Min) {
|
|||||||
// implement Abs(complex) as unary conveniently.
|
// implement Abs(complex) as unary conveniently.
|
||||||
//
|
//
|
||||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||||
XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(AbsComplex)) {
|
BINARY_TEST_FLOAT_32(DISABLED_ON_CPU(AbsComplex), {
|
||||||
auto host_abs_complex = [](float x, float y) {
|
auto host_abs_complex = [](float x, float y) {
|
||||||
return std::abs(std::complex<float>(x, y));
|
return std::abs(std::complex<float>(x, y));
|
||||||
};
|
};
|
||||||
auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };
|
auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };
|
||||||
|
|
||||||
Run(device_abs_complex, host_abs_complex);
|
Run(device_abs_complex, host_abs_complex);
|
||||||
}
|
})
|
||||||
|
|
||||||
#if defined(BINARY_TEST_TARGET_F32)
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
SpecialValues, ExhaustiveF32BinaryTest,
|
SpecialValues, ExhaustiveF32BinaryTest,
|
||||||
@ -307,51 +326,55 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
::testing::ValuesIn(
|
::testing::ValuesIn(
|
||||||
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 2000))));
|
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 2000))));
|
||||||
|
|
||||||
|
#if defined(BINARY_TEST_TARGET_F64) && \
|
||||||
|
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||||
|
#define BINARY_TEST_FLOAT_64(test_name, ...) \
|
||||||
|
XLA_TEST_P(ExhaustiveF64BinaryTest, test_name) \
|
||||||
|
__VA_ARGS__
|
||||||
|
#else
|
||||||
|
#define BINARY_TEST_FLOAT_64(test_name, ...)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64BinaryTest, Add) {
|
BINARY_TEST_FLOAT_64(Add, {
|
||||||
auto host_add = [](double x, double y) { return x + y; };
|
auto host_add = [](double x, double y) { return x + y; };
|
||||||
Run(AddEmptyBroadcastDimension(Add), host_add);
|
Run(AddEmptyBroadcastDimension(Add), host_add);
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64BinaryTest, Sub) {
|
BINARY_TEST_FLOAT_64(Sub, {
|
||||||
auto host_sub = [](double x, double y) { return x - y; };
|
auto host_sub = [](double x, double y) { return x - y; };
|
||||||
Run(AddEmptyBroadcastDimension(Sub), host_sub);
|
Run(AddEmptyBroadcastDimension(Sub), host_sub);
|
||||||
}
|
})
|
||||||
|
|
||||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||||
XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Mul)) {
|
BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(Mul), {
|
||||||
auto host_mul = [](double x, double y) { return x * y; };
|
auto host_mul = [](double x, double y) { return x * y; };
|
||||||
Run(AddEmptyBroadcastDimension(Mul), host_mul);
|
Run(AddEmptyBroadcastDimension(Mul), host_mul);
|
||||||
}
|
})
|
||||||
|
|
||||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||||
XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Div)) {
|
BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(Div), {
|
||||||
auto host_div = [](double x, double y) { return x / y; };
|
auto host_div = [](double x, double y) { return x / y; };
|
||||||
Run(AddEmptyBroadcastDimension(Div), host_div);
|
Run(AddEmptyBroadcastDimension(Div), host_div);
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64BinaryTest, Max) {
|
BINARY_TEST_FLOAT_64(Max, {
|
||||||
Run(AddEmptyBroadcastDimension(Max), ReferenceMax<double>);
|
Run(AddEmptyBroadcastDimension(Max), ReferenceMax<double>);
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64BinaryTest, Min) {
|
BINARY_TEST_FLOAT_64(Min, {
|
||||||
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<double>);
|
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<double>);
|
||||||
}
|
})
|
||||||
|
|
||||||
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
// TODO(bixia): Need to investigate the failure on CPU and file bugs.
|
||||||
XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(AbsComplex)) {
|
BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(AbsComplex), {
|
||||||
auto host_abs_complex = [](double x, double y) {
|
auto host_abs_complex = [](double x, double y) {
|
||||||
return std::abs(std::complex<double>(x, y));
|
return std::abs(std::complex<double>(x, y));
|
||||||
};
|
};
|
||||||
auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };
|
auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); };
|
||||||
|
|
||||||
Run(device_abs_complex, host_abs_complex);
|
Run(device_abs_complex, host_abs_complex);
|
||||||
}
|
})
|
||||||
|
|
||||||
#if defined(BINARY_TEST_TARGET_F64)
|
|
||||||
|
|
||||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
SpecialValues, ExhaustiveF64BinaryTest,
|
SpecialValues, ExhaustiveF64BinaryTest,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
@ -385,8 +408,6 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
|
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
|
||||||
::testing::ValuesIn(
|
::testing::ValuesIn(
|
||||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
|
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
|
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
|
||||||
#ifdef __FAST_MATH__
|
#ifdef __FAST_MATH__
|
||||||
#error "Can't be compiled with fast math on"
|
#error "Can't be compiled with fast math on"
|
||||||
@ -211,15 +213,54 @@ typedef Exhaustive32BitOrLessUnaryTest<F32> ExhaustiveF32UnaryTest;
|
|||||||
typedef Exhaustive32BitOrLessUnaryTest<F16> ExhaustiveF16UnaryTest;
|
typedef Exhaustive32BitOrLessUnaryTest<F16> ExhaustiveF16UnaryTest;
|
||||||
typedef Exhaustive32BitOrLessUnaryTest<BF16> ExhaustiveBF16UnaryTest;
|
typedef Exhaustive32BitOrLessUnaryTest<BF16> ExhaustiveBF16UnaryTest;
|
||||||
|
|
||||||
#define XLA_TEST_FLOAT_32_BITS_OR_LESS(test_name, ...) \
|
#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER)
|
||||||
|
#define NEED_UNARY_F32 true
|
||||||
|
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||||
|
#define NEED_UNARY_F16 true
|
||||||
|
#else
|
||||||
|
#define NEED_UNARY_F16 false
|
||||||
|
#endif
|
||||||
|
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
|
||||||
|
#define NEED_UNARY_BF16 true
|
||||||
|
#else
|
||||||
|
#define NEED_UNARY_BF16 false
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
#define NEED_UNARY_F32 false
|
||||||
|
#define NEED_UNARY_F16 false
|
||||||
|
#define NEED_UNARY_BF16 false
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if NEED_UNARY_F32
|
||||||
|
#define UNARY_TEST_F32(test_name, ...) \
|
||||||
XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \
|
XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__
|
||||||
|
#else
|
||||||
|
#define UNARY_TEST_F32(test_name, ...)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if NEED_UNARY_F16
|
||||||
|
#define UNARY_TEST_F16(test_name, ...) \
|
||||||
XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \
|
XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__
|
||||||
|
#else
|
||||||
|
#define UNARY_TEST_F16(test_name, ...)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if NEED_UNARY_BF16
|
||||||
|
#define UNARY_TEST_BF16(test_name, ...) \
|
||||||
XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \
|
XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \
|
||||||
__VA_ARGS__
|
__VA_ARGS__
|
||||||
|
#else
|
||||||
|
#define UNARY_TEST_BF16(test_name, ...)
|
||||||
|
#endif
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Log, {
|
#define UNARY_TEST_FLOAT_32_BITS_OR_LESS(test_name, ...) \
|
||||||
|
UNARY_TEST_F32(test_name, __VA_ARGS__) \
|
||||||
|
UNARY_TEST_F16(test_name, __VA_ARGS__) \
|
||||||
|
UNARY_TEST_BF16(test_name, __VA_ARGS__)
|
||||||
|
|
||||||
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Log, {
|
||||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||||
if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
|
if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
|
||||||
error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; };
|
error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; };
|
||||||
@ -227,7 +268,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Log, {
|
|||||||
Run(Log, std::log, error_spec_gen);
|
Run(Log, std::log, error_spec_gen);
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Log1p, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Log1p, {
|
||||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||||
if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
|
if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
|
||||||
error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; };
|
error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; };
|
||||||
@ -235,7 +276,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Log1p, {
|
|||||||
Run(Log1p, std::log1p, error_spec_gen);
|
Run(Log1p, std::log1p, error_spec_gen);
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Exp, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Exp, {
|
||||||
// When x < -105, the true value of exp(x) is smaller than the smallest F32,
|
// When x < -105, the true value of exp(x) is smaller than the smallest F32,
|
||||||
// so exp(x) should return exactly 0. We want our implementation of exp to
|
// so exp(x) should return exactly 0. We want our implementation of exp to
|
||||||
// return exactly 0 as well, as not doing so implies either that our
|
// return exactly 0 as well, as not doing so implies either that our
|
||||||
@ -266,7 +307,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Exp, {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Expm1, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Expm1, {
|
||||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||||
if (ty_ == F32) {
|
if (ty_ == F32) {
|
||||||
error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0.00015}; };
|
error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0.00015}; };
|
||||||
@ -292,7 +333,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Expm1, {
|
|||||||
// It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but
|
// It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but
|
||||||
// this *did* find a bug, namely that some backends were assuming sqrt(x) ==
|
// this *did* find a bug, namely that some backends were assuming sqrt(x) ==
|
||||||
// pow(x, 0.5), but this is not true for x == -inf.
|
// pow(x, 0.5), but this is not true for x == -inf.
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, {
|
||||||
EvaluateOp fn = +[](float x) { return std::pow(x, 0.5f); };
|
EvaluateOp fn = +[](float x) { return std::pow(x, 0.5f); };
|
||||||
// TODO(b/123837116): Enable the test for all values after fixing the bug.
|
// TODO(b/123837116): Enable the test for all values after fixing the bug.
|
||||||
if (platform_ != "Host" && platform_ != "CUDA") {
|
if (platform_ != "Host" && platform_ != "CUDA") {
|
||||||
@ -306,12 +347,12 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, {
|
|||||||
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn);
|
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn);
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Rsqrt, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Rsqrt, {
|
||||||
Run(
|
Run(
|
||||||
Rsqrt, +[](float x) { return 1 / std::sqrt(x); });
|
Rsqrt, +[](float x) { return 1 / std::sqrt(x); });
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, {
|
||||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||||
if (platform_ == "Host" || platform_ == "CUDA") {
|
if (platform_ == "Host" || platform_ == "CUDA") {
|
||||||
error_spec_gen = +[](NativeT x) {
|
error_spec_gen = +[](NativeT x) {
|
||||||
@ -349,11 +390,11 @@ XLA_TEST_P(ExhaustiveF32UnaryTest, Asinh) {
|
|||||||
XLA_TEST_P(ExhaustiveF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
|
XLA_TEST_P(ExhaustiveF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
|
||||||
XLA_TEST_P(ExhaustiveBF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
|
XLA_TEST_P(ExhaustiveBF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); })
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); })
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); })
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); })
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Asin, { Run(Asin, std::asin); })
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Asin, { Run(Asin, std::asin); })
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Cosh, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Cosh, {
|
||||||
// Our cosh implementation incorrectly overflows to inf for +/-89.4159851.
|
// Our cosh implementation incorrectly overflows to inf for +/-89.4159851.
|
||||||
// The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to
|
// The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to
|
||||||
// max-float, so we deem this acceptable.
|
// max-float, so we deem this acceptable.
|
||||||
@ -374,7 +415,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Cosh, {
|
|||||||
Run(Cosh, host_cosh);
|
Run(Cosh, host_cosh);
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Sinh, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sinh, {
|
||||||
// Our sinh implementation incorrectly overflows to +/-inf for +/-89.4159851.
|
// Our sinh implementation incorrectly overflows to +/-inf for +/-89.4159851.
|
||||||
// The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to
|
// The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to
|
||||||
// max-float, so we deem this acceptable.
|
// max-float, so we deem this acceptable.
|
||||||
@ -395,7 +436,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Sinh, {
|
|||||||
Run(Sinh, host_sinh);
|
Run(Sinh, host_sinh);
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Tanh, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Tanh, {
|
||||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||||
if (platform_ == "CUDA") {
|
if (platform_ == "CUDA") {
|
||||||
error_spec_gen = +[](NativeT x) {
|
error_spec_gen = +[](NativeT x) {
|
||||||
@ -429,62 +470,68 @@ void Exhaustive32BitOrLessUnaryTest<T>::SetParamsForSinCosTan() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF32UnaryTest, Cos) {
|
UNARY_TEST_F32(Cos, {
|
||||||
SetParamsForSinCosTan();
|
SetParamsForSinCosTan();
|
||||||
Run(
|
Run(
|
||||||
Cos, std::cos, +[](NativeT) {
|
Cos, std::cos, +[](NativeT) {
|
||||||
return ErrorSpec{0.001, 0.001};
|
return ErrorSpec{0.001, 0.001};
|
||||||
});
|
});
|
||||||
}
|
})
|
||||||
XLA_TEST_P(ExhaustiveF16UnaryTest, Cos) {
|
|
||||||
SetParamsForSinCosTan();
|
|
||||||
Run(Cos, std::cos);
|
|
||||||
}
|
|
||||||
XLA_TEST_P(ExhaustiveBF16UnaryTest, Cos) {
|
|
||||||
SetParamsForSinCosTan();
|
|
||||||
Run(Cos, std::cos);
|
|
||||||
}
|
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF32UnaryTest, Sin) {
|
UNARY_TEST_F16(Cos, {
|
||||||
|
SetParamsForSinCosTan();
|
||||||
|
Run(Cos, std::cos);
|
||||||
|
})
|
||||||
|
|
||||||
|
UNARY_TEST_BF16(Cos, {
|
||||||
|
SetParamsForSinCosTan();
|
||||||
|
Run(Cos, std::cos);
|
||||||
|
})
|
||||||
|
|
||||||
|
UNARY_TEST_F32(Sin, {
|
||||||
SetParamsForSinCosTan();
|
SetParamsForSinCosTan();
|
||||||
Run(
|
Run(
|
||||||
Sin, std::sin, +[](NativeT) {
|
Sin, std::sin, +[](NativeT) {
|
||||||
return ErrorSpec{0.001, 0.001};
|
return ErrorSpec{0.001, 0.001};
|
||||||
});
|
});
|
||||||
}
|
})
|
||||||
XLA_TEST_P(ExhaustiveF16UnaryTest, Sin) {
|
|
||||||
SetParamsForSinCosTan();
|
|
||||||
Run(Sin, std::sin);
|
|
||||||
}
|
|
||||||
XLA_TEST_P(ExhaustiveBF16UnaryTest, Sin) {
|
|
||||||
SetParamsForSinCosTan();
|
|
||||||
Run(Sin, std::sin);
|
|
||||||
}
|
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF32UnaryTest, Tan) {
|
UNARY_TEST_F16(Sin, {
|
||||||
|
SetParamsForSinCosTan();
|
||||||
|
Run(Sin, std::sin);
|
||||||
|
})
|
||||||
|
|
||||||
|
UNARY_TEST_BF16(Sin, {
|
||||||
|
SetParamsForSinCosTan();
|
||||||
|
Run(Sin, std::sin);
|
||||||
|
})
|
||||||
|
|
||||||
|
UNARY_TEST_F32(Tan, {
|
||||||
SetParamsForSinCosTan();
|
SetParamsForSinCosTan();
|
||||||
Run(
|
Run(
|
||||||
Tan, std::tan, +[](NativeT) {
|
Tan, std::tan, +[](NativeT) {
|
||||||
return ErrorSpec{0.001, 0.001};
|
return ErrorSpec{0.001, 0.001};
|
||||||
});
|
});
|
||||||
}
|
})
|
||||||
XLA_TEST_P(ExhaustiveF16UnaryTest, Tan) {
|
|
||||||
|
UNARY_TEST_F16(Tan, {
|
||||||
SetParamsForSinCosTan();
|
SetParamsForSinCosTan();
|
||||||
Run(Tan, std::tan);
|
Run(Tan, std::tan);
|
||||||
}
|
})
|
||||||
XLA_TEST_P(ExhaustiveBF16UnaryTest, Tan) {
|
|
||||||
|
UNARY_TEST_BF16(Tan, {
|
||||||
SetParamsForSinCosTan();
|
SetParamsForSinCosTan();
|
||||||
Run(Tan, std::tan);
|
Run(Tan, std::tan);
|
||||||
}
|
})
|
||||||
|
|
||||||
// TODO(jlebar): Enable these.
|
// TODO(jlebar): Enable these.
|
||||||
// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan) { Run(Atan, std::atan); }
|
// UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atan) { Run(Atan, std::atan); }
|
||||||
// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan2) { Run(Atan2, std::atan2); }
|
// UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atan2) { Run(Atan2, std::atan2); }
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Erf, { Run(Erf, std::erf); })
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Erf, { Run(Erf, std::erf); })
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Erfc, { Run(Erfc, std::erfc); })
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Erfc, { Run(Erfc, std::erfc); })
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(ErfInv, { Run(ErfInv, HostErfInv); })
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(ErfInv, { Run(ErfInv, HostErfInv); })
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Digamma, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Digamma, {
|
||||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||||
if (platform_ != "Host" && platform_ != "CUDA") {
|
if (platform_ != "Host" && platform_ != "CUDA") {
|
||||||
// TODO(b/123956399): This is a fairly high error, significantly higher than
|
// TODO(b/123956399): This is a fairly high error, significantly higher than
|
||||||
@ -514,7 +561,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Digamma, {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, {
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, {
|
||||||
// Our implementation gets within 0.0001 rel error except for ~20 denormal
|
// Our implementation gets within 0.0001 rel error except for ~20 denormal
|
||||||
// inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma.
|
// inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma.
|
||||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||||
@ -545,9 +592,7 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, {
|
|||||||
Run(Lgamma, host_lgamma, error_spec_gen);
|
Run(Lgamma, host_lgamma, error_spec_gen);
|
||||||
})
|
})
|
||||||
|
|
||||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); })
|
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); })
|
||||||
|
|
||||||
#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER)
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest,
|
INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest,
|
||||||
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
|
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
|
||||||
@ -562,8 +607,6 @@ INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest,
|
|||||||
::testing::Values(std::make_pair(0, 1 << 16)));
|
::testing::Values(std::make_pair(0, 1 << 16)));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Exhaustive test for unary operations for double.
|
// Exhaustive test for unary operations for double.
|
||||||
//
|
//
|
||||||
// Test parameter is a tuple containing
|
// Test parameter is a tuple containing
|
||||||
@ -594,42 +637,51 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest<F64>,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Log) { Run(Log, std::log); }
|
#if defined(UNARY_TEST_TARGET_F64) && \
|
||||||
|
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||||
|
#define UNARY_TEST_FLOAT_64(test_name, ...) \
|
||||||
|
XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \
|
||||||
|
__VA_ARGS__
|
||||||
|
#else
|
||||||
|
#define UNARY_TEST_FLOAT_64(test_name, ...)
|
||||||
|
#endif
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Log1p) { Run(Log1p, std::log1p); }
|
UNARY_TEST_FLOAT_64(Log, { Run(Log, std::log); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Exp) { Run(Exp, std::exp); }
|
UNARY_TEST_FLOAT_64(Log1p, { Run(Log1p, std::log1p); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Expm1) { Run(Expm1, std::expm1); }
|
UNARY_TEST_FLOAT_64(Exp, { Run(Exp, std::exp); })
|
||||||
|
|
||||||
|
UNARY_TEST_FLOAT_64(Expm1, { Run(Expm1, std::expm1); })
|
||||||
|
|
||||||
// TODO(b/138385863): Turn on the test for GPU after fixing the bug.
|
// TODO(b/138385863): Turn on the test for GPU after fixing the bug.
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, DISABLED_ON_GPU(PowOneHalf)) {
|
UNARY_TEST_FLOAT_64(DISABLED_ON_GPU(PowOneHalf), {
|
||||||
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
|
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
|
||||||
+[](double x) { return std::pow(x, 0.5); });
|
+[](double x) { return std::pow(x, 0.5); });
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Rsqrt) {
|
UNARY_TEST_FLOAT_64(Rsqrt, {
|
||||||
Run(
|
Run(
|
||||||
Rsqrt, +[](double x) { return 1 / std::sqrt(x); });
|
Rsqrt, +[](double x) { return 1 / std::sqrt(x); });
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Sqrt) { Run(Sqrt, std::sqrt); }
|
UNARY_TEST_FLOAT_64(Sqrt, { Run(Sqrt, std::sqrt); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Acosh) { Run(Acosh, std::acosh); }
|
UNARY_TEST_FLOAT_64(Acosh, { Run(Acosh, std::acosh); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Asinh) { Run(Asinh, std::asinh); }
|
UNARY_TEST_FLOAT_64(Asinh, { Run(Asinh, std::asinh); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Atanh) { Run(Atanh, std::atanh); }
|
UNARY_TEST_FLOAT_64(Atanh, { Run(Atanh, std::atanh); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Acos) { Run(Acos, std::acos); }
|
UNARY_TEST_FLOAT_64(Acos, { Run(Acos, std::acos); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Asin) { Run(Asin, std::asin); }
|
UNARY_TEST_FLOAT_64(Asin, { Run(Asin, std::asin); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Cosh) { Run(Cosh, std::cosh); }
|
UNARY_TEST_FLOAT_64(Cosh, { Run(Cosh, std::cosh); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Sinh) { Run(Sinh, std::sinh); }
|
UNARY_TEST_FLOAT_64(Sinh, { Run(Sinh, std::sinh); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Tanh) {
|
UNARY_TEST_FLOAT_64(Tanh, {
|
||||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||||
if (platform_ == "CUDA") {
|
if (platform_ == "CUDA") {
|
||||||
error_spec_gen = +[](NativeT x) {
|
error_spec_gen = +[](NativeT x) {
|
||||||
@ -639,26 +691,24 @@ XLA_TEST_P(ExhaustiveF64UnaryTest, Tanh) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
Run(Tanh, std::tanh, error_spec_gen);
|
Run(Tanh, std::tanh, error_spec_gen);
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Cos) { Run(Cos, std::cos); }
|
UNARY_TEST_FLOAT_64(Cos, { Run(Cos, std::cos); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Sin) { Run(Sin, std::sin); }
|
UNARY_TEST_FLOAT_64(Sin, { Run(Sin, std::sin); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Tan) { Run(Tan, std::tan); }
|
UNARY_TEST_FLOAT_64(Tan, { Run(Tan, std::tan); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Round) { Run(Round, std::round); }
|
UNARY_TEST_FLOAT_64(Round, { Run(Round, std::round); })
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Erf) {
|
UNARY_TEST_FLOAT_64(Erf, {
|
||||||
Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
|
Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveF64UnaryTest, Erfc) {
|
UNARY_TEST_FLOAT_64(Erfc, {
|
||||||
Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
|
Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
|
||||||
}
|
})
|
||||||
|
|
||||||
#if defined(UNARY_TEST_TARGET_F64)
|
|
||||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
SpecialValues, ExhaustiveF64UnaryTest,
|
SpecialValues, ExhaustiveF64UnaryTest,
|
||||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()));
|
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()));
|
||||||
@ -672,8 +722,6 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
LargeAndSmallMagnituedNormalValues, ExhaustiveF64UnaryTest,
|
LargeAndSmallMagnituedNormalValues, ExhaustiveF64UnaryTest,
|
||||||
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
|
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
|
||||||
4000000000ull, 16000000)));
|
4000000000ull, 16000000)));
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// T is the Primitive Type of the complex number
|
// T is the Primitive Type of the complex number
|
||||||
// Test parameter is a tuple containing
|
// Test parameter is a tuple containing
|
||||||
@ -741,30 +789,38 @@ class ExhaustiveComplexUnaryTestBase
|
|||||||
typedef ExhaustiveComplexUnaryTestBase<C64> ExhaustiveC64UnaryTest;
|
typedef ExhaustiveComplexUnaryTestBase<C64> ExhaustiveC64UnaryTest;
|
||||||
typedef ExhaustiveComplexUnaryTestBase<C128> ExhaustiveC128UnaryTest;
|
typedef ExhaustiveComplexUnaryTestBase<C128> ExhaustiveC128UnaryTest;
|
||||||
|
|
||||||
// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
|
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
||||||
XLA_TEST_P(ExhaustiveC64UnaryTest, DISABLED_ON_CPU(Log)) {
|
#define UNARY_TEST_COMPLEX_64(test_name, ...) \
|
||||||
Run(Log, [](complex64 x) { return std::log<float>(x); });
|
XLA_TEST_P(ExhaustiveC64UnaryTest, test_name) \
|
||||||
}
|
__VA_ARGS__
|
||||||
|
#else
|
||||||
|
#define UNARY_TEST_COMPLEX_64(test_name, ...)
|
||||||
|
#endif
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveC64UnaryTest, Sqrt) {
|
// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
|
||||||
|
UNARY_TEST_COMPLEX_64(DISABLED_ON_CPU(Log), {
|
||||||
|
Run(Log, [](complex64 x) { return std::log<float>(x); });
|
||||||
|
})
|
||||||
|
|
||||||
|
UNARY_TEST_COMPLEX_64(Sqrt, {
|
||||||
Run(Sqrt, [](complex64 x) {
|
Run(Sqrt, [](complex64 x) {
|
||||||
return static_cast<complex64>(
|
return static_cast<complex64>(
|
||||||
std::sqrt<double>(static_cast<complex128>(x)));
|
std::sqrt<double>(static_cast<complex128>(x)));
|
||||||
});
|
});
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveC64UnaryTest, Rsqrt) {
|
UNARY_TEST_COMPLEX_64(Rsqrt, {
|
||||||
Run(Rsqrt, [](complex64 x) {
|
Run(Rsqrt, [](complex64 x) {
|
||||||
return static_cast<complex64>(
|
return static_cast<complex64>(
|
||||||
complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x)));
|
complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x)));
|
||||||
});
|
});
|
||||||
}
|
})
|
||||||
|
|
||||||
// The current libc++ implementation of the complex tanh function provides
|
// The current libc++ implementation of the complex tanh function provides
|
||||||
// less accurate results when the denomenator of a complex tanh is small, due
|
// less accurate results when the denomenator of a complex tanh is small, due
|
||||||
// to floating point precision loss. To avoid this issue for complex64 numbers,
|
// to floating point precision loss. To avoid this issue for complex64 numbers,
|
||||||
// we cast it to and from a complex128 when computing tanh.
|
// we cast it to and from a complex128 when computing tanh.
|
||||||
XLA_TEST_P(ExhaustiveC64UnaryTest, Tanh) {
|
UNARY_TEST_COMPLEX_64(Tanh, {
|
||||||
SetParamsForTanh();
|
SetParamsForTanh();
|
||||||
ErrorSpecGen error_spec_gen = +[](complex64 x) {
|
ErrorSpecGen error_spec_gen = +[](complex64 x) {
|
||||||
// This implementation of Tanh becomes less accurate when the denominator
|
// This implementation of Tanh becomes less accurate when the denominator
|
||||||
@ -781,9 +837,8 @@ XLA_TEST_P(ExhaustiveC64UnaryTest, Tanh) {
|
|||||||
return static_cast<complex64>(std::tanh(static_cast<complex128>(x)));
|
return static_cast<complex64>(std::tanh(static_cast<complex128>(x)));
|
||||||
},
|
},
|
||||||
error_spec_gen);
|
error_spec_gen);
|
||||||
}
|
})
|
||||||
|
|
||||||
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
F32SpecialValues, ExhaustiveC64UnaryTest,
|
F32SpecialValues, ExhaustiveC64UnaryTest,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
@ -816,10 +871,17 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
4000)),
|
4000)),
|
||||||
::testing::ValuesIn(
|
::testing::ValuesIn(
|
||||||
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000))));
|
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000))));
|
||||||
|
|
||||||
|
#if defined(UNARY_TEST_TARGET_COMPLEX) && \
|
||||||
|
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||||
|
#define UNARY_TEST_COMPLEX_128(test_name, ...) \
|
||||||
|
XLA_TEST_P(ExhaustiveC128UnaryTest, test_name) \
|
||||||
|
__VA_ARGS__
|
||||||
|
#else
|
||||||
|
#define UNARY_TEST_COMPLEX_128(test_name, ...)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
UNARY_TEST_COMPLEX_128(Log, {
|
||||||
XLA_TEST_P(ExhaustiveC128UnaryTest, Log) {
|
|
||||||
// TODO(b/138578313): Enable the test for all values after fixing the bug.
|
// TODO(b/138578313): Enable the test for all values after fixing the bug.
|
||||||
known_incorrect_fn_ = [&](int64 v) {
|
known_incorrect_fn_ = [&](int64 v) {
|
||||||
double f = this->ConvertValue(v);
|
double f = this->ConvertValue(v);
|
||||||
@ -827,18 +889,18 @@ XLA_TEST_P(ExhaustiveC128UnaryTest, Log) {
|
|||||||
std::abs(f) < 1.0e-300;
|
std::abs(f) < 1.0e-300;
|
||||||
};
|
};
|
||||||
Run(Log, [](complex128 x) { return std::log<double>(x); });
|
Run(Log, [](complex128 x) { return std::log<double>(x); });
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveC128UnaryTest, Sqrt) {
|
UNARY_TEST_COMPLEX_128(Sqrt, {
|
||||||
// Similar to the Tanh bug.
|
// Similar to the Tanh bug.
|
||||||
known_incorrect_fn_ = [&](int64 v) {
|
known_incorrect_fn_ = [&](int64 v) {
|
||||||
double f = this->ConvertValue(v);
|
double f = this->ConvertValue(v);
|
||||||
return std::abs(f) > std::numeric_limits<double>::max() / 2;
|
return std::abs(f) > std::numeric_limits<double>::max() / 2;
|
||||||
};
|
};
|
||||||
Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); });
|
Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); });
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveC128UnaryTest, Rsqrt) {
|
UNARY_TEST_COMPLEX_128(Rsqrt, {
|
||||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||||
if (platform_ == "CUDA") {
|
if (platform_ == "CUDA") {
|
||||||
// Edge case on CUDA backend where the Log of a complex number made up of
|
// Edge case on CUDA backend where the Log of a complex number made up of
|
||||||
@ -856,16 +918,14 @@ XLA_TEST_P(ExhaustiveC128UnaryTest, Rsqrt) {
|
|||||||
Rsqrt,
|
Rsqrt,
|
||||||
[](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); },
|
[](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); },
|
||||||
error_spec_gen);
|
error_spec_gen);
|
||||||
}
|
})
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveC128UnaryTest, Tanh) {
|
UNARY_TEST_COMPLEX_128(Tanh, {
|
||||||
SetParamsForTanh();
|
SetParamsForTanh();
|
||||||
Run(
|
Run(
|
||||||
Tanh, +[](complex128 x) { return std::tanh(x); });
|
Tanh, +[](complex128 x) { return std::tanh(x); });
|
||||||
}
|
})
|
||||||
|
|
||||||
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
|
||||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
SpecialValues, ExhaustiveC128UnaryTest,
|
SpecialValues, ExhaustiveC128UnaryTest,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
@ -898,7 +958,5 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
|
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
|
||||||
::testing::ValuesIn(
|
::testing::ValuesIn(
|
||||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
|
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -259,17 +259,31 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); }
|
|||||||
|
|
||||||
XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }
|
XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1LargeTest, DoIt_F32) { Run<float>(GetParam()); }
|
// TODO(b/69425338): The following tests are disable on GPU because they use
|
||||||
|
// too much GPU memory.
|
||||||
|
XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_F32)) {
|
||||||
|
Run<float>(GetParam());
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1LargeTest, DoIt_F64) { Run<double>(GetParam()); }
|
XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_F64)) {
|
||||||
|
Run<double>(GetParam());
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1LargeTest, DoIt_U32) { Run<uint32>(GetParam()); }
|
XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_U32)) {
|
||||||
|
Run<uint32>(GetParam());
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1LargeTest, DoIt_S32) { Run<int32>(GetParam()); }
|
XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_S32)) {
|
||||||
|
Run<int32>(GetParam());
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1LargeTest, DoIt_U64) { Run<uint64>(GetParam()); }
|
XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_U64)) {
|
||||||
|
Run<uint64>(GetParam());
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run<int64>(GetParam()); }
|
XLA_TEST_P(SliceR1LargeTest, DISABLED_ON_GPU(DoIt_S64)) {
|
||||||
|
Run<int64>(GetParam());
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run<bool>(GetParam()); }
|
XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run<bool>(GetParam()); }
|
||||||
|
|
||||||
@ -315,8 +329,6 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
SliceR1TestDataToString
|
SliceR1TestDataToString
|
||||||
);
|
);
|
||||||
|
|
||||||
// TODO(b/69425338): This uses too much memory on GPU.
|
|
||||||
#ifndef XLA_TEST_BACKEND_GPU
|
|
||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_CASE_P(
|
||||||
SliceR1TestBigSlicesInstantiation,
|
SliceR1TestBigSlicesInstantiation,
|
||||||
SliceR1LargeTest,
|
SliceR1LargeTest,
|
||||||
@ -330,7 +342,6 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
),
|
),
|
||||||
SliceR1TestDataToString
|
SliceR1TestDataToString
|
||||||
);
|
);
|
||||||
#endif
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_CASE_P(
|
||||||
SliceStridedR1TestInstantiation,
|
SliceStridedR1TestInstantiation,
|
||||||
|
Loading…
Reference in New Issue
Block a user