Split up exhaustive_unary_test.cc into several files.

We have three different test targets, each needing a different section of the
currently existing source file. There is no reason not to split this up into
three files. This also allows us to get rid of the #define flags, and we can
fix the issue that required use of the
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST macro.

PiperOrigin-RevId: 293566054
Change-Id: Iad1c8ab39a629f3cf4f4327933317473f11209d1
This commit is contained in:
Adrian Kuegel 2020-02-06 04:31:09 -08:00 committed by TensorFlower Gardener
parent af49800eda
commit fc1961d9c1
5 changed files with 422 additions and 393 deletions

View File

@ -777,8 +777,7 @@ cc_library(
xla_test(
name = "exhaustive_unary_test_f32_or_smaller",
srcs = ["exhaustive_unary_test.cc"],
copts = ["-DUNARY_TEST_TARGET_F32_OR_SMALLER"],
srcs = ["exhaustive_unary_test_f32_or_smaller.cc"],
real_hardware_only = True, # Very slow on the interpreter.
shard_count = 48,
tags = [
@ -796,13 +795,11 @@ xla_test(
xla_test(
name = "exhaustive_unary_test_f64",
srcs = ["exhaustive_unary_test.cc"],
srcs = ["exhaustive_unary_test_f64.cc"],
backends = [
"gpu",
"cpu",
],
copts = ["-DUNARY_TEST_TARGET_F64"],
real_hardware_only = True, # Very slow on the interpreter.
shard_count = 48,
tags = [
"optonly",
@ -819,13 +816,11 @@ xla_test(
xla_test(
name = "exhaustive_unary_test_complex",
srcs = ["exhaustive_unary_test.cc"],
srcs = ["exhaustive_unary_test_complex.cc"],
backends = [
"gpu",
"cpu",
],
copts = ["-DUNARY_TEST_TARGET_COMPLEX"],
real_hardware_only = True, # Very slow on the interpreter.
shard_count = 48,
tags = [
"optonly",

View File

@ -1004,6 +1004,15 @@ typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator() {
return DefaultSpecGenerator<T, N>;
}
// Base fixture shared by the exhaustive tests of unary (one-operand) ops on
// primitive type T. It fixes the operand count of ExhaustiveOpTestBase to 1
// and exposes the default error-spec generator for that arity.
template <PrimitiveType T>
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
 public:
  // Type of the callback that yields the error tolerance for a given input;
  // re-exported from the base fixture so tests can name it unqualified.
  using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;

  // Returns the namespace-level default spec generator for a one-operand test
  // on T.
  static ErrorSpecGen GetDefaultSpecGenerator() {
    ErrorSpecGen default_generator =
        exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
    return default_generator;
  }
};
} // namespace exhaustive_op_test
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_

View File

@ -0,0 +1,256 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#ifdef __FAST_MATH__
#error "Can't be compiled with fast math on"
#endif
namespace xla {
namespace exhaustive_op_test {
// Exhaustive test fixture for unary operations on complex numbers.
//
// T is the complex primitive type under test (instantiated below for C64 and
// C128).
//
// The test parameter is a tuple of two FpValues: the first describes the
// values used for the real component, the second the values used for the
// imaginary component. The complex inputs of one test are the cartesian
// product of the values represented by the two FpValues.
template <PrimitiveType T>
class ExhaustiveComplexUnaryTestBase
    : public ExhaustiveUnaryTest<T>,
      public ::testing::WithParamInterface<std::tuple<FpValues, FpValues>> {
 protected:
  using typename ExhaustiveUnaryTest<T>::NativeT;

  // Installs the known-incorrect predicate used by the Tanh tests. The
  // predicate is evaluated per component value (real or imaginary).
  void SetParamsForTanh() {
    // TODO(b/138126045): Current libc++ implementation of the complex tanh
    //                    function returns (NaN, NaN) when the imaginary
    //                    component is more than half of the max value.
    // TODO(b/138750327): Current libc++ implementation of the complex tanh
    //                    function returns (1, 0) when the real component is
    //                    negative infinity, when it should return (-1, 0).
    // We only need to set the former as incorrect values for C128 because when
    // testing with C64, we first cast our input to a C128 value.
    //
    // The callback is stored in a member and therefore outlives this call, so
    // `this` is captured explicitly rather than through a by-reference default
    // capture.
    this->known_incorrect_fn_ = [this](int64 v) {
      double f = this->ConvertValue(v);
      return (T == C128 &&
              std::abs(f) > std::numeric_limits<double>::max() / 2) ||
             f == -std::numeric_limits<double>::infinity();
    };
  }

 private:
  // Generates the input complex literal given the FpValues representation for
  // the real and imaginary components. The real component is the outer loop,
  // so all imaginary values are emitted for one real value before moving on.
  // Component values flagged by known_incorrect_fn_ are replaced with 1.
  void FillInput(std::array<Literal, 1>* input_literal) override {
    FpValues real_values = std::get<0>(GetParam());
    FpValues imag_values = std::get<1>(GetParam());

    VLOG(2) << " testing input total "
            << real_values.GetTotalNumValues() * imag_values.GetTotalNumValues()
            << ", range " << real_values.ToString() << " "
            << imag_values.ToString();

    const int64 input_size = (*input_literal)[0].element_count();
    absl::Span<NativeT> input_arr = (*input_literal)[0].data<NativeT>();

    int64 i = 0;
    for (auto real : real_values) {
      for (auto imag : imag_values) {
        input_arr[i] =
            NativeT(this->ConvertAndReplaceKnownIncorrectValueWith(real, 1),
                    this->ConvertAndReplaceKnownIncorrectValueWith(imag, 1));
        ++i;
      }
    }
    // Same sanity check as the F64 fixture: the number of generated inputs
    // must match the size of the literal that was allocated for them.
    CHECK_EQ(i, input_size);
  }

  // Number of complex inputs of one test: |real values| * |imaginary values|.
  int64 GetInputSize() override {
    FpValues real_values = std::get<0>(GetParam());
    FpValues imag_values = std::get<1>(GetParam());
    return real_values.GetTotalNumValues() * imag_values.GetTotalNumValues();
  }
};

using ExhaustiveC64UnaryTest = ExhaustiveComplexUnaryTestBase<C64>;
using ExhaustiveC128UnaryTest = ExhaustiveComplexUnaryTestBase<C128>;
// Declares a parameterized test `test_name` on the ExhaustiveC64UnaryTest
// fixture. The braced test body is taken through __VA_ARGS__ so that commas
// inside the body do not split it into several macro arguments.
#define UNARY_TEST_COMPLEX_64(test_name, ...) \
XLA_TEST_P(ExhaustiveC64UnaryTest, test_name) \
__VA_ARGS__
// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
// Compares Log against std::log evaluated directly on the complex64 input.
UNARY_TEST_COMPLEX_64(DISABLED_ON_CPU(Log), {
Run(Log, [](complex64 x) { return std::log<float>(x); });
})
// Compares Sqrt against a reference that widens the input to complex128,
// takes std::sqrt in double precision and narrows the result back to
// complex64.
UNARY_TEST_COMPLEX_64(Sqrt, {
Run(Sqrt, [](complex64 x) {
return static_cast<complex64>(
std::sqrt<double>(static_cast<complex128>(x)));
});
})
// Compares Rsqrt against 1 / sqrt(x), with the reference computed in
// complex128 and narrowed back to complex64.
UNARY_TEST_COMPLEX_64(Rsqrt, {
Run(Rsqrt, [](complex64 x) {
return static_cast<complex64>(
complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x)));
});
})
// The current libc++ implementation of the complex tanh function provides
// less accurate results when the denominator of a complex tanh is small, due
// to floating point precision loss. To avoid this issue for complex64 numbers,
// we cast the input to complex128, compute tanh there, and cast the result
// back to complex64.
UNARY_TEST_COMPLEX_64(Tanh, {
SetParamsForTanh();
// The unary + turns the captureless lambda into a plain function pointer
// before it is assigned to ErrorSpecGen.
ErrorSpecGen error_spec_gen = +[](complex64 x) {
// This implementation of Tanh becomes less accurate when the denominator
// is small. cosh(2 * re) + cos(2 * im) is that denominator; below 1e-4 the
// tolerance is loosened to 5e-2 for both error-spec fields.
if (std::cosh(2 * x.real()) + std::cos(2 * x.imag()) < 1e-4) {
return ErrorSpec{5e-2, 5e-2};
}
return GetDefaultSpecGenerator()(x);
};
Run(
Tanh,
+[](complex64 x) {
return static_cast<complex64>(std::tanh(static_cast<complex128>(x)));
},
error_spec_gen);
})
// Each instantiation below pairs a set of real-component values (first
// generator) with a set of imaginary-component values (second generator);
// ::testing::Combine yields every (real set, imaginary set) pair as one test
// parameter.

// Real and imaginary components both drawn from
// CreateFpValuesForBoundaryTest<float>().
INSTANTIATE_TEST_SUITE_P(
F32SpecialValues, ExhaustiveC64UnaryTest,
::testing::Combine(
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
// Real component from the boundary-test sets, imaginary component from
// GetNormals<float>(10000).
INSTANTIATE_TEST_SUITE_P(
F32SpecialAndNormalValues, ExhaustiveC64UnaryTest,
::testing::Combine(
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
::testing::Values(GetNormals<float>(10000))));
// Real component from GetNormals<float>(10000), imaginary component from the
// boundary-test sets.
INSTANTIATE_TEST_SUITE_P(
F32NormalAndSpecialValues, ExhaustiveC64UnaryTest,
::testing::Combine(
::testing::Values(GetNormals<float>(10000)),
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
// Both components from GetNormals<float>(10000).
INSTANTIATE_TEST_SUITE_P(
F32NormalAndNormalValues, ExhaustiveC64UnaryTest,
::testing::Combine(::testing::Values(GetNormals<float>(10000)),
::testing::Values(GetNormals<float>(10000))));
// Tests a total of 40000 ^ 2 inputs, with 4000 ^ 2 inputs in each sub-test, to
// keep the peak memory usage low.
INSTANTIATE_TEST_SUITE_P(
F32LargeAndSmallMagnitudeNormalValues, ExhaustiveC64UnaryTest,
::testing::Combine(
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<float>(40000,
4000)),
::testing::ValuesIn(
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000))));
// Declares a parameterized test `test_name` on the ExhaustiveC128UnaryTest
// fixture. The braced test body is taken through __VA_ARGS__ so that commas
// inside the body do not split it into several macro arguments.
#define UNARY_TEST_COMPLEX_128(test_name, ...) \
XLA_TEST_P(ExhaustiveC128UnaryTest, test_name) \
__VA_ARGS__
// Compares Log against std::log on complex128, skipping component values for
// which the comparison is known to fail.
UNARY_TEST_COMPLEX_128(Log, {
  // TODO(b/138578313): Enable the test for all values after fixing the bug.
  // Until then a component value is treated as known-incorrect when it is NaN,
  // when its magnitude exceeds 1e+300, or when its magnitude is below 1e-300
  // (which includes zero).
  //
  // The callback is stored in a member, so `this` is captured explicitly
  // rather than through a by-reference default capture.
  known_incorrect_fn_ = [this](int64 v) {
    const double f = this->ConvertValue(v);
    return std::isnan(f) || std::abs(f) > 1.0e+300 || std::abs(f) < 1.0e-300;
  };
  Run(Log, [](complex128 x) { return std::log<double>(x); });
})
// Compares Sqrt against std::sqrt on complex128.
UNARY_TEST_COMPLEX_128(Sqrt, {
  // Similar to the Tanh bug: component values whose magnitude is more than
  // half of the largest finite double are treated as known-incorrect.
  //
  // The callback is stored in a member, so `this` is captured explicitly
  // rather than through a by-reference default capture.
  known_incorrect_fn_ = [this](int64 v) {
    const double f = this->ConvertValue(v);
    return std::abs(f) > std::numeric_limits<double>::max() / 2;
  };
  Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); });
})
// Compares Rsqrt against 1 / sqrt(x) computed with std::sqrt on complex128.
UNARY_TEST_COMPLEX_128(Rsqrt, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ == "CUDA") {
// Edge case on CUDA backend where the Log of a complex number made up of
// the smallest denormals is more accurate than the interpreter backend.
// NOTE(review): this is the Rsqrt test; the mention of "Log" above looks like
// a copy-paste leftover and probably should read "Rsqrt" -- confirm.
error_spec_gen = [](complex128 x) {
constexpr double denorm_min = std::numeric_limits<double>::denorm_min();
// Loosen the tolerance to 0.5 only when both components have magnitude
// exactly denorm_min; every other input keeps the default spec.
if (std::abs(x.real()) == denorm_min &&
std::abs(x.imag()) == denorm_min) {
return ErrorSpec(0.5, 0.5);
}
return GetDefaultSpecGenerator()(x);
};
}
Run(
Rsqrt,
[](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); },
error_spec_gen);
})
// Compares Tanh against std::tanh on complex128. SetParamsForTanh() first
// installs the predicate that excludes the component values for which the
// host reference is known to be wrong.
UNARY_TEST_COMPLEX_128(Tanh, {
SetParamsForTanh();
Run(
Tanh, +[](complex128 x) { return std::tanh(x); });
})
// Each instantiation below pairs a set of real-component values (first
// generator) with a set of imaginary-component values (second generator);
// ::testing::Combine yields every (real set, imaginary set) pair as one test
// parameter.

// Both components drawn from CreateFpValuesForBoundaryTest<double>().
INSTANTIATE_TEST_SUITE_P(
    SpecialValues, ExhaustiveC128UnaryTest,
    ::testing::Combine(
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));

// Real component from the boundary-test sets, imaginary component from
// GetNormals<double>(10000).
INSTANTIATE_TEST_SUITE_P(
    SpecialAndNormalValues, ExhaustiveC128UnaryTest,
    ::testing::Combine(
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
        ::testing::Values(GetNormals<double>(10000))));

// Real component from GetNormals<double>(10000), imaginary component from the
// boundary-test sets.
INSTANTIATE_TEST_SUITE_P(
    NormalAndSpecialValues, ExhaustiveC128UnaryTest,
    ::testing::Combine(
        ::testing::Values(GetNormals<double>(10000)),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));

// Both components from GetNormals<double>(10000). The instantiation prefix
// carries no "F32" because this suite is built from double values, matching
// the naming of the sibling C128 instantiations.
INSTANTIATE_TEST_SUITE_P(
    NormalAndNormalValues, ExhaustiveC128UnaryTest,
    ::testing::Combine(::testing::Values(GetNormals<double>(10000)),
                       ::testing::Values(GetNormals<double>(10000))));

// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test, to
// keep the peak memory usage low.
INSTANTIATE_TEST_SUITE_P(
    LargeAndSmallMagnitudeNormalValues, ExhaustiveC128UnaryTest,
    ::testing::Combine(
        ::testing::ValuesIn(
            GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
        ::testing::ValuesIn(
            GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
} // namespace exhaustive_op_test
} // namespace xla

View File

@ -158,15 +158,6 @@ float HostDigamma(float x) {
return result - reflection;
}
template <PrimitiveType T>
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
public:
using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;
static ErrorSpecGen GetDefaultSpecGenerator() {
return exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
}
};
// Exhaustive test for unary operations for <= 32bit floating point types.
//
// Test parameter is a tuple containing
@ -216,15 +207,10 @@ class Exhaustive32BitOrLessUnaryTest
}
};
typedef Exhaustive32BitOrLessUnaryTest<F32> ExhaustiveF32UnaryTest;
typedef Exhaustive32BitOrLessUnaryTest<F16> ExhaustiveF16UnaryTest;
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
ExhaustiveF16UnaryTest); // TODO(b/139702016) go/are-your-tests-running
using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest<F32>;
using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest<F16>;
using ExhaustiveBF16UnaryTest = Exhaustive32BitOrLessUnaryTest<BF16>;
typedef Exhaustive32BitOrLessUnaryTest<BF16> ExhaustiveBF16UnaryTest;
#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER)
#define NEED_UNARY_F32 true
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
#define NEED_UNARY_F16 true
#else
@ -235,19 +221,10 @@ typedef Exhaustive32BitOrLessUnaryTest<BF16> ExhaustiveBF16UnaryTest;
#else
#define NEED_UNARY_BF16 false
#endif
#else
#define NEED_UNARY_F32 false
#define NEED_UNARY_F16 false
#define NEED_UNARY_BF16 false
#endif
#if NEED_UNARY_F32
#define UNARY_TEST_F32(test_name, ...) \
XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \
__VA_ARGS__
#else
#define UNARY_TEST_F32(test_name, ...)
#endif
#if NEED_UNARY_F16
#define UNARY_TEST_F16(test_name, ...) \
@ -385,8 +362,14 @@ XLA_TEST_P(ExhaustiveF32UnaryTest, Acosh) {
Run(Acosh, std::acosh, error_spec_gen);
}
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
XLA_TEST_P(ExhaustiveF16UnaryTest, Acosh) { Run(Acosh, std::acosh); }
#endif
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
XLA_TEST_P(ExhaustiveBF16UnaryTest, Acosh) { Run(Acosh, std::acosh); }
#endif
// Tests for Asinh
XLA_TEST_P(ExhaustiveF32UnaryTest, Asinh) {
@ -397,8 +380,14 @@ XLA_TEST_P(ExhaustiveF32UnaryTest, Asinh) {
Run(Asinh, std::asinh, error_spec_gen);
}
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
XLA_TEST_P(ExhaustiveF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
#endif
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
XLA_TEST_P(ExhaustiveBF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
#endif
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); })
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); })
@ -625,364 +614,5 @@ INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest,
::testing::Values(std::make_pair(0, 1 << 16)));
#endif
// Exhaustive test for unary operations for double.
//
// Test parameter is a tuple containing
// - primitive type under test,
// - FpValues representing a set of double values.
class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest<F64>,
public ::testing::WithParamInterface<FpValues> {
private:
int64 GetInputSize() override {
FpValues values = GetParam();
return values.GetTotalNumValues();
}
void FillInput(std::array<Literal, 1>* input_literal) override {
FpValues fp_values = GetParam();
int64 input_size = (*input_literal)[0].element_count();
LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", "
<< input_size;
absl::Span<double> input_arr = (*input_literal)[0].data<double>();
uint64 i = 0;
for (auto bits : fp_values) {
input_arr[i] = this->ConvertAndReplaceKnownIncorrectValueWith(bits, 1);
++i;
}
CHECK_EQ(i, input_size);
}
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
ExhaustiveF64UnaryTest); // TODO(b/139702016) go/are-your-tests-running
#if defined(UNARY_TEST_TARGET_F64) && \
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
#define UNARY_TEST_FLOAT_64(test_name, ...) \
XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \
__VA_ARGS__
#else
#define UNARY_TEST_FLOAT_64(test_name, ...)
#endif
UNARY_TEST_FLOAT_64(Log, { Run(Log, std::log); })
UNARY_TEST_FLOAT_64(Log1p, { Run(Log1p, std::log1p); })
UNARY_TEST_FLOAT_64(Exp, { Run(Exp, std::exp); })
UNARY_TEST_FLOAT_64(Expm1, { Run(Expm1, std::expm1); })
// TODO(b/138385863): Turn on the test for GPU after fixing the bug.
UNARY_TEST_FLOAT_64(DISABLED_ON_GPU(PowOneHalf), {
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
+[](double x) { return std::pow(x, 0.5); });
})
UNARY_TEST_FLOAT_64(Rsqrt, {
Run(
Rsqrt, +[](double x) { return 1 / std::sqrt(x); });
})
UNARY_TEST_FLOAT_64(Sqrt, { Run(Sqrt, std::sqrt); })
UNARY_TEST_FLOAT_64(Acosh, { Run(Acosh, std::acosh); })
UNARY_TEST_FLOAT_64(Asinh, { Run(Asinh, std::asinh); })
UNARY_TEST_FLOAT_64(Atanh, { Run(Atanh, std::atanh); })
UNARY_TEST_FLOAT_64(Acos, { Run(Acos, std::acos); })
UNARY_TEST_FLOAT_64(Asin, { Run(Asin, std::asin); })
UNARY_TEST_FLOAT_64(Cosh, { Run(Cosh, std::cosh); })
UNARY_TEST_FLOAT_64(Sinh, { Run(Sinh, std::sinh); })
UNARY_TEST_FLOAT_64(Tanh, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ == "CUDA") {
error_spec_gen = +[](NativeT x) {
return x <= static_cast<NativeT>(-20.0) || x >= static_cast<NativeT>(20.0)
? ErrorSpec{0, 0}
: GetDefaultSpecGenerator()(x);
};
}
Run(Tanh, std::tanh, error_spec_gen);
})
UNARY_TEST_FLOAT_64(Cos, { Run(Cos, std::cos); })
UNARY_TEST_FLOAT_64(Sin, { Run(Sin, std::sin); })
UNARY_TEST_FLOAT_64(Tan, { Run(Tan, std::tan); })
UNARY_TEST_FLOAT_64(Round, { Run(Round, std::round); })
UNARY_TEST_FLOAT_64(Erf, {
Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
})
UNARY_TEST_FLOAT_64(Erfc, {
Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
})
INSTANTIATE_TEST_SUITE_P(
SpecialValues, ExhaustiveF64UnaryTest,
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()));
INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest,
::testing::Values(GetNormals<double>(1000)));
// Tests a total of 4000000000 inputs, with 16000000 inputs in each sub-test, to
// keep the peak memory usage low.
INSTANTIATE_TEST_SUITE_P(
LargeAndSmallMagnitudeNormalValues, ExhaustiveF64UnaryTest,
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
4000000000ull, 16000000)));
// T is the Primitive Type of the complex number
// Test parameter is a tuple containing
// - primitive type under test,
// - two FpValues representing the values for the real and imaginary
// components. The complex numbers for the test input is the cartesian
// product of the values represented by the two FpValues.
template <PrimitiveType T>
class ExhaustiveComplexUnaryTestBase
: public ExhaustiveUnaryTest<T>,
public ::testing::WithParamInterface<std::tuple<FpValues, FpValues>> {
protected:
using typename ExhaustiveUnaryTest<T>::NativeT;
void SetParamsForTanh() {
// TODO(b/138126045): Current libc++ implementation of the complex tanh
// function returns (NaN, NaN) when the imaginary
// component is more than half of the max value.
// TODO(b/138750327): Current libc++ implementation of the complex tanh
// function returns (1, 0) when the real component is
// negative infinity, when it should return (-1, 0).
// We only need to set the former as incorrect values for C128 because when
// testing with C64, we first cast our input to a C128 value.
this->known_incorrect_fn_ = [&](int64 v) {
double f = this->ConvertValue(v);
return (T == C128 &&
std::abs(f) > std::numeric_limits<double>::max() / 2) ||
f == -std::numeric_limits<double>::infinity();
};
}
private:
// Generates the input complex literal given the FpValues representation for
// the real and imaginary components.
void FillInput(std::array<Literal, 1>* input_literal) override {
FpValues real_values = std::get<0>(GetParam());
FpValues imag_values = std::get<1>(GetParam());
VLOG(2) << " testing input total "
<< real_values.GetTotalNumValues() * imag_values.GetTotalNumValues()
<< ", range " << real_values.ToString() << " "
<< imag_values.ToString();
absl::Span<NativeT> input_arr = (*input_literal)[0].data<NativeT>();
uint64 i = 0;
for (auto real : real_values) {
for (auto imag : imag_values) {
input_arr[i] =
NativeT(this->ConvertAndReplaceKnownIncorrectValueWith(real, 1),
this->ConvertAndReplaceKnownIncorrectValueWith(imag, 1));
++i;
}
}
}
int64 GetInputSize() override {
FpValues real_values = std::get<0>(GetParam());
FpValues imag_values = std::get<1>(GetParam());
return real_values.GetTotalNumValues() * imag_values.GetTotalNumValues();
}
};
typedef ExhaustiveComplexUnaryTestBase<C64> ExhaustiveC64UnaryTest;
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
ExhaustiveC64UnaryTest); // TODO(b/139702016) go/are-your-tests-running
typedef ExhaustiveComplexUnaryTestBase<C128> ExhaustiveC128UnaryTest;
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
ExhaustiveC128UnaryTest); // TODO(b/139702016) go/are-your-tests-running
#if defined(UNARY_TEST_TARGET_COMPLEX)
#define UNARY_TEST_COMPLEX_64(test_name, ...) \
XLA_TEST_P(ExhaustiveC64UnaryTest, test_name) \
__VA_ARGS__
#else
#define UNARY_TEST_COMPLEX_64(test_name, ...)
#endif
// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
UNARY_TEST_COMPLEX_64(DISABLED_ON_CPU(Log), {
Run(Log, [](complex64 x) { return std::log<float>(x); });
})
UNARY_TEST_COMPLEX_64(Sqrt, {
Run(Sqrt, [](complex64 x) {
return static_cast<complex64>(
std::sqrt<double>(static_cast<complex128>(x)));
});
})
UNARY_TEST_COMPLEX_64(Rsqrt, {
Run(Rsqrt, [](complex64 x) {
return static_cast<complex64>(
complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x)));
});
})
// The current libc++ implementation of the complex tanh function provides
// less accurate results when the denominator of a complex tanh is small, due
// to floating point precision loss. To avoid this issue for complex64 numbers,
// we cast it to and from a complex128 when computing tanh.
UNARY_TEST_COMPLEX_64(Tanh, {
SetParamsForTanh();
ErrorSpecGen error_spec_gen = +[](complex64 x) {
// This implementation of Tanh becomes less accurate when the denominator
// is small.
if (std::cosh(2 * x.real()) + std::cos(2 * x.imag()) < 1e-4) {
return ErrorSpec{5e-2, 5e-2};
}
return GetDefaultSpecGenerator()(x);
};
Run(
Tanh,
+[](complex64 x) {
return static_cast<complex64>(std::tanh(static_cast<complex128>(x)));
},
error_spec_gen);
})
INSTANTIATE_TEST_SUITE_P(
F32SpecialValues, ExhaustiveC64UnaryTest,
::testing::Combine(
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
INSTANTIATE_TEST_SUITE_P(
F32SpecialAndNormalValues, ExhaustiveC64UnaryTest,
::testing::Combine(
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
::testing::Values(GetNormals<float>(10000))));
INSTANTIATE_TEST_SUITE_P(
F32NormalAndSpecialValues, ExhaustiveC64UnaryTest,
::testing::Combine(
::testing::Values(GetNormals<float>(10000)),
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
INSTANTIATE_TEST_SUITE_P(
F32NormalAndNormalValues, ExhaustiveC64UnaryTest,
::testing::Combine(::testing::Values(GetNormals<float>(10000)),
::testing::Values(GetNormals<float>(10000))));
// Tests a total of 40000 ^ 2 inputs, with 4000 ^ 2 inputs in each sub-test, to
// keep the peak memory usage low.
INSTANTIATE_TEST_SUITE_P(
F32LargeAndSmallMagnitudeNormalValues, ExhaustiveC64UnaryTest,
::testing::Combine(
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<float>(40000,
4000)),
::testing::ValuesIn(
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000))));
#if defined(UNARY_TEST_TARGET_COMPLEX) && \
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
#define UNARY_TEST_COMPLEX_128(test_name, ...) \
XLA_TEST_P(ExhaustiveC128UnaryTest, test_name) \
__VA_ARGS__
#else
#define UNARY_TEST_COMPLEX_128(test_name, ...)
#endif
UNARY_TEST_COMPLEX_128(Log, {
// TODO(b/138578313): Enable the test for all values after fixing the bug.
known_incorrect_fn_ = [&](int64 v) {
double f = this->ConvertValue(v);
return std::fpclassify(f) == FP_NAN || std::abs(f) > 1.0e+300 ||
std::abs(f) < 1.0e-300;
};
Run(Log, [](complex128 x) { return std::log<double>(x); });
})
UNARY_TEST_COMPLEX_128(Sqrt, {
// Similar to the Tanh bug.
known_incorrect_fn_ = [&](int64 v) {
double f = this->ConvertValue(v);
return std::abs(f) > std::numeric_limits<double>::max() / 2;
};
Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); });
})
UNARY_TEST_COMPLEX_128(Rsqrt, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ == "CUDA") {
// Edge case on CUDA backend where the Log of a complex number made up of
// the smallest denormals is more accurate than the interpreter backend.
error_spec_gen = [](complex128 x) {
constexpr double denorm_min = std::numeric_limits<double>::denorm_min();
if (std::abs(x.real()) == denorm_min &&
std::abs(x.imag()) == denorm_min) {
return ErrorSpec(0.5, 0.5);
}
return GetDefaultSpecGenerator()(x);
};
}
Run(
Rsqrt,
[](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); },
error_spec_gen);
})
UNARY_TEST_COMPLEX_128(Tanh, {
SetParamsForTanh();
Run(
Tanh, +[](complex128 x) { return std::tanh(x); });
})
INSTANTIATE_TEST_SUITE_P(
SpecialValues, ExhaustiveC128UnaryTest,
::testing::Combine(
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
INSTANTIATE_TEST_SUITE_P(
SpecialAndNormalValues, ExhaustiveC128UnaryTest,
::testing::Combine(
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
::testing::Values(GetNormals<double>(10000))));
INSTANTIATE_TEST_SUITE_P(
NormalAndSpecialValues, ExhaustiveC128UnaryTest,
::testing::Combine(
::testing::Values(GetNormals<double>(10000)),
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
INSTANTIATE_TEST_SUITE_P(
F32NormalAndNormalValues, ExhaustiveC128UnaryTest,
::testing::Combine(::testing::Values(GetNormals<double>(10000)),
::testing::Values(GetNormals<double>(10000))));
// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test, to
// keep the peak memory usage low.
INSTANTIATE_TEST_SUITE_P(
LargeAndSmallMagnitudeNormalValues, ExhaustiveC128UnaryTest,
::testing::Combine(
::testing::ValuesIn(
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
::testing::ValuesIn(
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
} // namespace exhaustive_op_test
} // namespace xla

View File

@ -0,0 +1,139 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#ifdef __FAST_MATH__
#error "Can't be compiled with fast math on"
#endif
namespace xla {
namespace exhaustive_op_test {
// Exhaustive test for unary operations for double.
//
// The primitive type under test is fixed to F64 through the base class. The
// test parameter is a single FpValues representing the set of double values
// that is fed to the operation.
class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest<F64>,
                               public ::testing::WithParamInterface<FpValues> {
 private:
  // Number of inputs of one test: the number of values in the FpValues
  // parameter.
  int64 GetInputSize() override {
    FpValues values = GetParam();
    return values.GetTotalNumValues();
  }

  // Fills the single input literal with the doubles described by the FpValues
  // parameter. Values flagged by known_incorrect_fn_ are replaced with 1.
  void FillInput(std::array<Literal, 1>* input_literal) override {
    FpValues fp_values = GetParam();
    int64 input_size = (*input_literal)[0].element_count();
    LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", "
              << input_size;

    absl::Span<double> input_arr = (*input_literal)[0].data<double>();

    // Signed counter so that the CHECK_EQ below compares two int64 values
    // instead of mixing signed and unsigned operands.
    int64 i = 0;
    for (auto bits : fp_values) {
      input_arr[i] = this->ConvertAndReplaceKnownIncorrectValueWith(bits, 1);
      ++i;
    }
    // The number of generated values must match the size of the literal.
    CHECK_EQ(i, input_size);
  }
};
// Declares a parameterized test `test_name` on the ExhaustiveF64UnaryTest
// fixture. The braced test body is taken through __VA_ARGS__ so that commas
// inside the body do not split it into several macro arguments.
#define UNARY_TEST_FLOAT_64(test_name, ...) \
XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \
__VA_ARGS__
// Each test below passes Run() the XLA op under test and the host function
// that serves as its reference; with no third argument Run() is called without
// an explicit error-spec generator.
UNARY_TEST_FLOAT_64(Log, { Run(Log, std::log); })
UNARY_TEST_FLOAT_64(Log1p, { Run(Log1p, std::log1p); })
UNARY_TEST_FLOAT_64(Exp, { Run(Exp, std::exp); })
UNARY_TEST_FLOAT_64(Expm1, { Run(Expm1, std::expm1); })
// TODO(b/138385863): Turn on the test for GPU after fixing the bug.
// Exercises Pow with a constant exponent of 0.5 against std::pow(x, 0.5).
UNARY_TEST_FLOAT_64(DISABLED_ON_GPU(PowOneHalf), {
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
+[](double x) { return std::pow(x, 0.5); });
})
// The host reference for Rsqrt is 1 / sqrt(x); the unary + turns the
// captureless lambda into a plain function pointer.
UNARY_TEST_FLOAT_64(Rsqrt, {
Run(
Rsqrt, +[](double x) { return 1 / std::sqrt(x); });
})
UNARY_TEST_FLOAT_64(Sqrt, { Run(Sqrt, std::sqrt); })
UNARY_TEST_FLOAT_64(Acosh, { Run(Acosh, std::acosh); })
UNARY_TEST_FLOAT_64(Asinh, { Run(Asinh, std::asinh); })
UNARY_TEST_FLOAT_64(Atanh, { Run(Atanh, std::atanh); })
UNARY_TEST_FLOAT_64(Acos, { Run(Acos, std::acos); })
UNARY_TEST_FLOAT_64(Asin, { Run(Asin, std::asin); })
UNARY_TEST_FLOAT_64(Cosh, { Run(Cosh, std::cosh); })
UNARY_TEST_FLOAT_64(Sinh, { Run(Sinh, std::sinh); })
// Compares Tanh against std::tanh. On the CUDA platform the default error
// spec is replaced by a stricter one for inputs of large magnitude.
UNARY_TEST_FLOAT_64(Tanh, {
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
if (platform_ == "CUDA") {
// For x <= -20 or x >= 20 an all-zero ErrorSpec is used; all other inputs
// keep the default spec.
error_spec_gen = +[](NativeT x) {
return x <= static_cast<NativeT>(-20.0) || x >= static_cast<NativeT>(20.0)
? ErrorSpec{0, 0}
: GetDefaultSpecGenerator()(x);
};
}
Run(Tanh, std::tanh, error_spec_gen);
})
// Trigonometric and rounding ops compared against their <cmath> counterparts
// without an explicit error-spec generator.
UNARY_TEST_FLOAT_64(Cos, { Run(Cos, std::cos); })
UNARY_TEST_FLOAT_64(Sin, { Run(Sin, std::sin); })
UNARY_TEST_FLOAT_64(Tan, { Run(Tan, std::tan); })
UNARY_TEST_FLOAT_64(Round, { Run(Round, std::round); })
// Erf and Erfc use a fixed ErrorSpec{1e-20, 1e-20} for every input instead of
// the default generator; the lambda ignores its argument.
UNARY_TEST_FLOAT_64(Erf, {
Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
})
UNARY_TEST_FLOAT_64(Erfc, {
Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
})
// One test parameter per FpValues set returned by
// CreateFpValuesForBoundaryTest<double>().
INSTANTIATE_TEST_SUITE_P(
SpecialValues, ExhaustiveF64UnaryTest,
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()));
// A single test parameter: the FpValues returned by GetNormals<double>(1000).
INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest,
::testing::Values(GetNormals<double>(1000)));
// Tests a total of 4000000000 inputs, with 16000000 inputs in each sub-test, to
// keep the peak memory usage low.
INSTANTIATE_TEST_SUITE_P(
LargeAndSmallMagnitudeNormalValues, ExhaustiveF64UnaryTest,
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
4000000000ull, 16000000)));
} // namespace exhaustive_op_test
} // namespace xla