Split up exhaustive_unary_test.cc into several files.
We have three different test targets, each needing a different section of the currently existing source file. There is no reason not to split this up into three files. This also allows us to get rid of #define flags, and we can fix the issue that required usage of GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST macro. PiperOrigin-RevId: 293566054 Change-Id: Iad1c8ab39a629f3cf4f4327933317473f11209d1
This commit is contained in:
parent
af49800eda
commit
fc1961d9c1
@ -777,8 +777,7 @@ cc_library(
|
||||
|
||||
xla_test(
|
||||
name = "exhaustive_unary_test_f32_or_smaller",
|
||||
srcs = ["exhaustive_unary_test.cc"],
|
||||
copts = ["-DUNARY_TEST_TARGET_F32_OR_SMALLER"],
|
||||
srcs = ["exhaustive_unary_test_f32_or_smaller.cc"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
tags = [
|
||||
@ -796,13 +795,11 @@ xla_test(
|
||||
|
||||
xla_test(
|
||||
name = "exhaustive_unary_test_f64",
|
||||
srcs = ["exhaustive_unary_test.cc"],
|
||||
srcs = ["exhaustive_unary_test_f64.cc"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
copts = ["-DUNARY_TEST_TARGET_F64"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
tags = [
|
||||
"optonly",
|
||||
@ -819,13 +816,11 @@ xla_test(
|
||||
|
||||
xla_test(
|
||||
name = "exhaustive_unary_test_complex",
|
||||
srcs = ["exhaustive_unary_test.cc"],
|
||||
srcs = ["exhaustive_unary_test_complex.cc"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
copts = ["-DUNARY_TEST_TARGET_COMPLEX"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
tags = [
|
||||
"optonly",
|
||||
|
@ -1004,6 +1004,15 @@ typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator() {
|
||||
return DefaultSpecGenerator<T, N>;
|
||||
}
|
||||
|
||||
template <PrimitiveType T>
|
||||
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
|
||||
public:
|
||||
using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;
|
||||
static ErrorSpecGen GetDefaultSpecGenerator() {
|
||||
return exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace exhaustive_op_test
|
||||
} // namespace xla
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
|
||||
|
256
tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc
Normal file
256
tensorflow/compiler/xla/tests/exhaustive_unary_test_complex.cc
Normal file
@ -0,0 +1,256 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
#ifdef __FAST_MATH__
|
||||
#error "Can't be compiled with fast math on"
|
||||
#endif
|
||||
|
||||
namespace xla {
|
||||
namespace exhaustive_op_test {
|
||||
|
||||
// T is the Primitive Type of the complex number
|
||||
// Test parameter is a tuple containing
|
||||
// - primitive type under test,
|
||||
// - two FpValues representing the values for the real and imaginary
|
||||
// components. The complex numbers for the test input is the cartesian
|
||||
// product of the values represented by the two FpValues.
|
||||
template <PrimitiveType T>
|
||||
class ExhaustiveComplexUnaryTestBase
|
||||
: public ExhaustiveUnaryTest<T>,
|
||||
public ::testing::WithParamInterface<std::tuple<FpValues, FpValues>> {
|
||||
protected:
|
||||
using typename ExhaustiveUnaryTest<T>::NativeT;
|
||||
|
||||
void SetParamsForTanh() {
|
||||
// TODO(b/138126045): Current libc++ implementation of the complex tanh
|
||||
// function returns (NaN, NaN) when the imaginary
|
||||
// component is more than half of the max value.
|
||||
// TODO(b/138750327): Current libc++ implementation of the complex tanh
|
||||
// function returns (1, 0) when the real component is
|
||||
// negative infinity, when it should return (-1, 0).
|
||||
// We only need to set the former as incorrect values for C128 because when
|
||||
// testing with C64, we first cast our input to a C128 value.
|
||||
this->known_incorrect_fn_ = [&](int64 v) {
|
||||
double f = this->ConvertValue(v);
|
||||
return (T == C128 &&
|
||||
std::abs(f) > std::numeric_limits<double>::max() / 2) ||
|
||||
f == -std::numeric_limits<double>::infinity();
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
// Generates the input complex literal given the FpValues representation for
|
||||
// the real and imaginary components.
|
||||
void FillInput(std::array<Literal, 1>* input_literal) override {
|
||||
FpValues real_values = std::get<0>(GetParam());
|
||||
FpValues imag_values = std::get<1>(GetParam());
|
||||
|
||||
VLOG(2) << " testing input total "
|
||||
<< real_values.GetTotalNumValues() * imag_values.GetTotalNumValues()
|
||||
<< ", range " << real_values.ToString() << " "
|
||||
<< imag_values.ToString();
|
||||
|
||||
absl::Span<NativeT> input_arr = (*input_literal)[0].data<NativeT>();
|
||||
|
||||
uint64 i = 0;
|
||||
for (auto real : real_values) {
|
||||
for (auto imag : imag_values) {
|
||||
input_arr[i] =
|
||||
NativeT(this->ConvertAndReplaceKnownIncorrectValueWith(real, 1),
|
||||
this->ConvertAndReplaceKnownIncorrectValueWith(imag, 1));
|
||||
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64 GetInputSize() override {
|
||||
FpValues real_values = std::get<0>(GetParam());
|
||||
FpValues imag_values = std::get<1>(GetParam());
|
||||
return real_values.GetTotalNumValues() * imag_values.GetTotalNumValues();
|
||||
}
|
||||
};
|
||||
|
||||
using ExhaustiveC64UnaryTest = ExhaustiveComplexUnaryTestBase<C64>;
|
||||
|
||||
using ExhaustiveC128UnaryTest = ExhaustiveComplexUnaryTestBase<C128>;
|
||||
|
||||
#define UNARY_TEST_COMPLEX_64(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveC64UnaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
|
||||
// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
|
||||
UNARY_TEST_COMPLEX_64(DISABLED_ON_CPU(Log), {
|
||||
Run(Log, [](complex64 x) { return std::log<float>(x); });
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_64(Sqrt, {
|
||||
Run(Sqrt, [](complex64 x) {
|
||||
return static_cast<complex64>(
|
||||
std::sqrt<double>(static_cast<complex128>(x)));
|
||||
});
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_64(Rsqrt, {
|
||||
Run(Rsqrt, [](complex64 x) {
|
||||
return static_cast<complex64>(
|
||||
complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x)));
|
||||
});
|
||||
})
|
||||
|
||||
// The current libc++ implementation of the complex tanh function provides
|
||||
// less accurate results when the denomenator of a complex tanh is small, due
|
||||
// to floating point precision loss. To avoid this issue for complex64 numbers,
|
||||
// we cast it to and from a complex128 when computing tanh.
|
||||
UNARY_TEST_COMPLEX_64(Tanh, {
|
||||
SetParamsForTanh();
|
||||
ErrorSpecGen error_spec_gen = +[](complex64 x) {
|
||||
// This implementation of Tanh becomes less accurate when the denominator
|
||||
// is small.
|
||||
if (std::cosh(2 * x.real()) + std::cos(2 * x.imag()) < 1e-4) {
|
||||
return ErrorSpec{5e-2, 5e-2};
|
||||
}
|
||||
|
||||
return GetDefaultSpecGenerator()(x);
|
||||
};
|
||||
Run(
|
||||
Tanh,
|
||||
+[](complex64 x) {
|
||||
return static_cast<complex64>(std::tanh(static_cast<complex128>(x)));
|
||||
},
|
||||
error_spec_gen);
|
||||
})
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32SpecialValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32SpecialAndNormalValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
|
||||
::testing::Values(GetNormals<float>(10000))));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32NormalAndSpecialValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(GetNormals<float>(10000)),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32NormalAndNormalValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(::testing::Values(GetNormals<float>(10000)),
|
||||
::testing::Values(GetNormals<float>(10000))));
|
||||
|
||||
// Tests a total of 40000 ^ 2 inputs, with 4000 ^ 2 inputs in each sub-test, to
|
||||
// keep the peak memory usage low.
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32LargeAndSmallMagnitudeNormalValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<float>(40000,
|
||||
4000)),
|
||||
::testing::ValuesIn(
|
||||
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000))));
|
||||
|
||||
#define UNARY_TEST_COMPLEX_128(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveC128UnaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
|
||||
UNARY_TEST_COMPLEX_128(Log, {
|
||||
// TODO(b/138578313): Enable the test for all values after fixing the bug.
|
||||
known_incorrect_fn_ = [&](int64 v) {
|
||||
double f = this->ConvertValue(v);
|
||||
return std::fpclassify(f) == FP_NAN || std::abs(f) > 1.0e+300 ||
|
||||
std::abs(f) < 1.0e-300;
|
||||
};
|
||||
Run(Log, [](complex128 x) { return std::log<double>(x); });
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_128(Sqrt, {
|
||||
// Similar to the Tanh bug.
|
||||
known_incorrect_fn_ = [&](int64 v) {
|
||||
double f = this->ConvertValue(v);
|
||||
return std::abs(f) > std::numeric_limits<double>::max() / 2;
|
||||
};
|
||||
Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); });
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_128(Rsqrt, {
|
||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||
if (platform_ == "CUDA") {
|
||||
// Edge case on CUDA backend where the Log of a complex number made up of
|
||||
// the smallest denormals is more accurate than the interpreter backend.
|
||||
error_spec_gen = [](complex128 x) {
|
||||
constexpr double denorm_min = std::numeric_limits<double>::denorm_min();
|
||||
if (std::abs(x.real()) == denorm_min &&
|
||||
std::abs(x.imag()) == denorm_min) {
|
||||
return ErrorSpec(0.5, 0.5);
|
||||
}
|
||||
return GetDefaultSpecGenerator()(x);
|
||||
};
|
||||
}
|
||||
Run(
|
||||
Rsqrt,
|
||||
[](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); },
|
||||
error_spec_gen);
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_128(Tanh, {
|
||||
SetParamsForTanh();
|
||||
Run(
|
||||
Tanh, +[](complex128 x) { return std::tanh(x); });
|
||||
})
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
SpecialValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
SpecialAndNormalValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
|
||||
::testing::Values(GetNormals<double>(10000))));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
NormalAndSpecialValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(GetNormals<double>(10000)),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32NormalAndNormalValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(::testing::Values(GetNormals<double>(10000)),
|
||||
::testing::Values(GetNormals<double>(10000))));
|
||||
|
||||
// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test, to
|
||||
// keep the peak memory usage low.
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
LargeAndSmallMagnitudeNormalValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(
|
||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
|
||||
::testing::ValuesIn(
|
||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
|
||||
|
||||
} // namespace exhaustive_op_test
|
||||
} // namespace xla
|
@ -158,15 +158,6 @@ float HostDigamma(float x) {
|
||||
return result - reflection;
|
||||
}
|
||||
|
||||
template <PrimitiveType T>
|
||||
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
|
||||
public:
|
||||
using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;
|
||||
static ErrorSpecGen GetDefaultSpecGenerator() {
|
||||
return exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
|
||||
}
|
||||
};
|
||||
|
||||
// Exhaustive test for unary operations for <= 32bit floating point types.
|
||||
//
|
||||
// Test parameter is a tuple containing
|
||||
@ -216,15 +207,10 @@ class Exhaustive32BitOrLessUnaryTest
|
||||
}
|
||||
};
|
||||
|
||||
typedef Exhaustive32BitOrLessUnaryTest<F32> ExhaustiveF32UnaryTest;
|
||||
typedef Exhaustive32BitOrLessUnaryTest<F16> ExhaustiveF16UnaryTest;
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
|
||||
ExhaustiveF16UnaryTest); // TODO(b/139702016) go/are-your-tests-running
|
||||
using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest<F32>;
|
||||
using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest<F16>;
|
||||
using ExhaustiveBF16UnaryTest = Exhaustive32BitOrLessUnaryTest<BF16>;
|
||||
|
||||
typedef Exhaustive32BitOrLessUnaryTest<BF16> ExhaustiveBF16UnaryTest;
|
||||
|
||||
#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER)
|
||||
#define NEED_UNARY_F32 true
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||
#define NEED_UNARY_F16 true
|
||||
#else
|
||||
@ -235,19 +221,10 @@ typedef Exhaustive32BitOrLessUnaryTest<BF16> ExhaustiveBF16UnaryTest;
|
||||
#else
|
||||
#define NEED_UNARY_BF16 false
|
||||
#endif
|
||||
#else
|
||||
#define NEED_UNARY_F32 false
|
||||
#define NEED_UNARY_F16 false
|
||||
#define NEED_UNARY_BF16 false
|
||||
#endif
|
||||
|
||||
#if NEED_UNARY_F32
|
||||
#define UNARY_TEST_F32(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
#else
|
||||
#define UNARY_TEST_F32(test_name, ...)
|
||||
#endif
|
||||
|
||||
#if NEED_UNARY_F16
|
||||
#define UNARY_TEST_F16(test_name, ...) \
|
||||
@ -385,8 +362,14 @@ XLA_TEST_P(ExhaustiveF32UnaryTest, Acosh) {
|
||||
|
||||
Run(Acosh, std::acosh, error_spec_gen);
|
||||
}
|
||||
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||
XLA_TEST_P(ExhaustiveF16UnaryTest, Acosh) { Run(Acosh, std::acosh); }
|
||||
#endif
|
||||
|
||||
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
|
||||
XLA_TEST_P(ExhaustiveBF16UnaryTest, Acosh) { Run(Acosh, std::acosh); }
|
||||
#endif
|
||||
|
||||
// Tests for Asinh
|
||||
XLA_TEST_P(ExhaustiveF32UnaryTest, Asinh) {
|
||||
@ -397,8 +380,14 @@ XLA_TEST_P(ExhaustiveF32UnaryTest, Asinh) {
|
||||
|
||||
Run(Asinh, std::asinh, error_spec_gen);
|
||||
}
|
||||
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||
XLA_TEST_P(ExhaustiveF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
|
||||
#endif
|
||||
|
||||
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
|
||||
XLA_TEST_P(ExhaustiveBF16UnaryTest, Asinh) { Run(Asinh, std::asinh); }
|
||||
#endif
|
||||
|
||||
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); })
|
||||
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); })
|
||||
@ -625,364 +614,5 @@ INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest,
|
||||
::testing::Values(std::make_pair(0, 1 << 16)));
|
||||
#endif
|
||||
|
||||
// Exhaustive test for unary operations for double.
|
||||
//
|
||||
// Test parameter is a tuple containing
|
||||
// - primitive type under test,
|
||||
// - FpValues representing a set of double values.
|
||||
|
||||
class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest<F64>,
|
||||
public ::testing::WithParamInterface<FpValues> {
|
||||
private:
|
||||
int64 GetInputSize() override {
|
||||
FpValues values = GetParam();
|
||||
return values.GetTotalNumValues();
|
||||
}
|
||||
|
||||
void FillInput(std::array<Literal, 1>* input_literal) override {
|
||||
FpValues fp_values = GetParam();
|
||||
int64 input_size = (*input_literal)[0].element_count();
|
||||
LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", "
|
||||
<< input_size;
|
||||
absl::Span<double> input_arr = (*input_literal)[0].data<double>();
|
||||
|
||||
uint64 i = 0;
|
||||
for (auto bits : fp_values) {
|
||||
input_arr[i] = this->ConvertAndReplaceKnownIncorrectValueWith(bits, 1);
|
||||
++i;
|
||||
}
|
||||
CHECK_EQ(i, input_size);
|
||||
}
|
||||
};
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
|
||||
ExhaustiveF64UnaryTest); // TODO(b/139702016) go/are-your-tests-running
|
||||
|
||||
#if defined(UNARY_TEST_TARGET_F64) && \
|
||||
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||
#define UNARY_TEST_FLOAT_64(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
#else
|
||||
#define UNARY_TEST_FLOAT_64(test_name, ...)
|
||||
#endif
|
||||
|
||||
UNARY_TEST_FLOAT_64(Log, { Run(Log, std::log); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Log1p, { Run(Log1p, std::log1p); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Exp, { Run(Exp, std::exp); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Expm1, { Run(Expm1, std::expm1); })
|
||||
|
||||
// TODO(b/138385863): Turn on the test for GPU after fixing the bug.
|
||||
UNARY_TEST_FLOAT_64(DISABLED_ON_GPU(PowOneHalf), {
|
||||
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
|
||||
+[](double x) { return std::pow(x, 0.5); });
|
||||
})
|
||||
|
||||
UNARY_TEST_FLOAT_64(Rsqrt, {
|
||||
Run(
|
||||
Rsqrt, +[](double x) { return 1 / std::sqrt(x); });
|
||||
})
|
||||
|
||||
UNARY_TEST_FLOAT_64(Sqrt, { Run(Sqrt, std::sqrt); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Acosh, { Run(Acosh, std::acosh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Asinh, { Run(Asinh, std::asinh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Atanh, { Run(Atanh, std::atanh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Acos, { Run(Acos, std::acos); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Asin, { Run(Asin, std::asin); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Cosh, { Run(Cosh, std::cosh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Sinh, { Run(Sinh, std::sinh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Tanh, {
|
||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||
if (platform_ == "CUDA") {
|
||||
error_spec_gen = +[](NativeT x) {
|
||||
return x <= static_cast<NativeT>(-20.0) || x >= static_cast<NativeT>(20.0)
|
||||
? ErrorSpec{0, 0}
|
||||
: GetDefaultSpecGenerator()(x);
|
||||
};
|
||||
}
|
||||
Run(Tanh, std::tanh, error_spec_gen);
|
||||
})
|
||||
|
||||
UNARY_TEST_FLOAT_64(Cos, { Run(Cos, std::cos); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Sin, { Run(Sin, std::sin); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Tan, { Run(Tan, std::tan); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Round, { Run(Round, std::round); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Erf, {
|
||||
Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
|
||||
})
|
||||
|
||||
UNARY_TEST_FLOAT_64(Erfc, {
|
||||
Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
|
||||
})
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
SpecialValues, ExhaustiveF64UnaryTest,
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest,
|
||||
::testing::Values(GetNormals<double>(1000)));
|
||||
|
||||
// Tests a total of 4000000000 inputs, with 16000000 inputs in each sub-test, to
|
||||
// keep the peak memory usage low.
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
LargeAndSmallMagnitudeNormalValues, ExhaustiveF64UnaryTest,
|
||||
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
|
||||
4000000000ull, 16000000)));
|
||||
|
||||
// T is the Primitive Type of the complex number
|
||||
// Test parameter is a tuple containing
|
||||
// - primitive type under test,
|
||||
// - two FpValues representing the values for the real and imaginary
|
||||
// components. The complex numbers for the test input is the cartesian
|
||||
// product of the values represented by the two FpValues.
|
||||
template <PrimitiveType T>
|
||||
class ExhaustiveComplexUnaryTestBase
|
||||
: public ExhaustiveUnaryTest<T>,
|
||||
public ::testing::WithParamInterface<std::tuple<FpValues, FpValues>> {
|
||||
protected:
|
||||
using typename ExhaustiveUnaryTest<T>::NativeT;
|
||||
|
||||
void SetParamsForTanh() {
|
||||
// TODO(b/138126045): Current libc++ implementation of the complex tanh
|
||||
// function returns (NaN, NaN) when the imaginary
|
||||
// component is more than half of the max value.
|
||||
// TODO(b/138750327): Current libc++ implementation of the complex tanh
|
||||
// function returns (1, 0) when the real component is
|
||||
// negative infinity, when it should return (-1, 0).
|
||||
// We only need to set the former as incorrect values for C128 because when
|
||||
// testing with C64, we first cast our input to a C128 value.
|
||||
this->known_incorrect_fn_ = [&](int64 v) {
|
||||
double f = this->ConvertValue(v);
|
||||
return (T == C128 &&
|
||||
std::abs(f) > std::numeric_limits<double>::max() / 2) ||
|
||||
f == -std::numeric_limits<double>::infinity();
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
// Generates the input complex literal given the FpValues representation for
|
||||
// the real and imaginary components.
|
||||
void FillInput(std::array<Literal, 1>* input_literal) override {
|
||||
FpValues real_values = std::get<0>(GetParam());
|
||||
FpValues imag_values = std::get<1>(GetParam());
|
||||
|
||||
VLOG(2) << " testing input total "
|
||||
<< real_values.GetTotalNumValues() * imag_values.GetTotalNumValues()
|
||||
<< ", range " << real_values.ToString() << " "
|
||||
<< imag_values.ToString();
|
||||
|
||||
absl::Span<NativeT> input_arr = (*input_literal)[0].data<NativeT>();
|
||||
|
||||
uint64 i = 0;
|
||||
for (auto real : real_values) {
|
||||
for (auto imag : imag_values) {
|
||||
input_arr[i] =
|
||||
NativeT(this->ConvertAndReplaceKnownIncorrectValueWith(real, 1),
|
||||
this->ConvertAndReplaceKnownIncorrectValueWith(imag, 1));
|
||||
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64 GetInputSize() override {
|
||||
FpValues real_values = std::get<0>(GetParam());
|
||||
FpValues imag_values = std::get<1>(GetParam());
|
||||
return real_values.GetTotalNumValues() * imag_values.GetTotalNumValues();
|
||||
}
|
||||
};
|
||||
|
||||
typedef ExhaustiveComplexUnaryTestBase<C64> ExhaustiveC64UnaryTest;
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
|
||||
ExhaustiveC64UnaryTest); // TODO(b/139702016) go/are-your-tests-running
|
||||
|
||||
typedef ExhaustiveComplexUnaryTestBase<C128> ExhaustiveC128UnaryTest;
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
|
||||
ExhaustiveC128UnaryTest); // TODO(b/139702016) go/are-your-tests-running
|
||||
|
||||
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
||||
#define UNARY_TEST_COMPLEX_64(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveC64UnaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
#else
|
||||
#define UNARY_TEST_COMPLEX_64(test_name, ...)
|
||||
#endif
|
||||
|
||||
// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
|
||||
UNARY_TEST_COMPLEX_64(DISABLED_ON_CPU(Log), {
|
||||
Run(Log, [](complex64 x) { return std::log<float>(x); });
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_64(Sqrt, {
|
||||
Run(Sqrt, [](complex64 x) {
|
||||
return static_cast<complex64>(
|
||||
std::sqrt<double>(static_cast<complex128>(x)));
|
||||
});
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_64(Rsqrt, {
|
||||
Run(Rsqrt, [](complex64 x) {
|
||||
return static_cast<complex64>(
|
||||
complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x)));
|
||||
});
|
||||
})
|
||||
|
||||
// The current libc++ implementation of the complex tanh function provides
|
||||
// less accurate results when the denomenator of a complex tanh is small, due
|
||||
// to floating point precision loss. To avoid this issue for complex64 numbers,
|
||||
// we cast it to and from a complex128 when computing tanh.
|
||||
UNARY_TEST_COMPLEX_64(Tanh, {
|
||||
SetParamsForTanh();
|
||||
ErrorSpecGen error_spec_gen = +[](complex64 x) {
|
||||
// This implementation of Tanh becomes less accurate when the denominator
|
||||
// is small.
|
||||
if (std::cosh(2 * x.real()) + std::cos(2 * x.imag()) < 1e-4) {
|
||||
return ErrorSpec{5e-2, 5e-2};
|
||||
}
|
||||
|
||||
return GetDefaultSpecGenerator()(x);
|
||||
};
|
||||
Run(
|
||||
Tanh,
|
||||
+[](complex64 x) {
|
||||
return static_cast<complex64>(std::tanh(static_cast<complex128>(x)));
|
||||
},
|
||||
error_spec_gen);
|
||||
})
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32SpecialValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32SpecialAndNormalValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
|
||||
::testing::Values(GetNormals<float>(10000))));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32NormalAndSpecialValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(GetNormals<float>(10000)),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32NormalAndNormalValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(::testing::Values(GetNormals<float>(10000)),
|
||||
::testing::Values(GetNormals<float>(10000))));
|
||||
|
||||
// Tests a total of 40000 ^ 2 inputs, with 4000 ^ 2 inputs in each sub-test, to
|
||||
// keep the peak memory usage low.
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32LargeAndSmallMagnitudeNormalValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<float>(40000,
|
||||
4000)),
|
||||
::testing::ValuesIn(
|
||||
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000))));
|
||||
|
||||
#if defined(UNARY_TEST_TARGET_COMPLEX) && \
|
||||
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||
#define UNARY_TEST_COMPLEX_128(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveC128UnaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
#else
|
||||
#define UNARY_TEST_COMPLEX_128(test_name, ...)
|
||||
#endif
|
||||
|
||||
UNARY_TEST_COMPLEX_128(Log, {
|
||||
// TODO(b/138578313): Enable the test for all values after fixing the bug.
|
||||
known_incorrect_fn_ = [&](int64 v) {
|
||||
double f = this->ConvertValue(v);
|
||||
return std::fpclassify(f) == FP_NAN || std::abs(f) > 1.0e+300 ||
|
||||
std::abs(f) < 1.0e-300;
|
||||
};
|
||||
Run(Log, [](complex128 x) { return std::log<double>(x); });
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_128(Sqrt, {
|
||||
// Similar to the Tanh bug.
|
||||
known_incorrect_fn_ = [&](int64 v) {
|
||||
double f = this->ConvertValue(v);
|
||||
return std::abs(f) > std::numeric_limits<double>::max() / 2;
|
||||
};
|
||||
Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); });
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_128(Rsqrt, {
|
||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||
if (platform_ == "CUDA") {
|
||||
// Edge case on CUDA backend where the Log of a complex number made up of
|
||||
// the smallest denormals is more accurate than the interpreter backend.
|
||||
error_spec_gen = [](complex128 x) {
|
||||
constexpr double denorm_min = std::numeric_limits<double>::denorm_min();
|
||||
if (std::abs(x.real()) == denorm_min &&
|
||||
std::abs(x.imag()) == denorm_min) {
|
||||
return ErrorSpec(0.5, 0.5);
|
||||
}
|
||||
return GetDefaultSpecGenerator()(x);
|
||||
};
|
||||
}
|
||||
Run(
|
||||
Rsqrt,
|
||||
[](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); },
|
||||
error_spec_gen);
|
||||
})
|
||||
|
||||
UNARY_TEST_COMPLEX_128(Tanh, {
|
||||
SetParamsForTanh();
|
||||
Run(
|
||||
Tanh, +[](complex128 x) { return std::tanh(x); });
|
||||
})
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
SpecialValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
SpecialAndNormalValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
|
||||
::testing::Values(GetNormals<double>(10000))));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
NormalAndSpecialValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(GetNormals<double>(10000)),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32NormalAndNormalValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(::testing::Values(GetNormals<double>(10000)),
|
||||
::testing::Values(GetNormals<double>(10000))));
|
||||
|
||||
// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test, to
|
||||
// keep the peak memory usage low.
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
LargeAndSmallMagnitudeNormalValues, ExhaustiveC128UnaryTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(
|
||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
|
||||
::testing::ValuesIn(
|
||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
|
||||
|
||||
} // namespace exhaustive_op_test
|
||||
} // namespace xla
|
139
tensorflow/compiler/xla/tests/exhaustive_unary_test_f64.cc
Normal file
139
tensorflow/compiler/xla/tests/exhaustive_unary_test_f64.cc
Normal file
@ -0,0 +1,139 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
#ifdef __FAST_MATH__
|
||||
#error "Can't be compiled with fast math on"
|
||||
#endif
|
||||
|
||||
namespace xla {
|
||||
namespace exhaustive_op_test {
|
||||
|
||||
// Exhaustive test for unary operations for double.
|
||||
//
|
||||
// Test parameter is a tuple containing
|
||||
// - primitive type under test,
|
||||
// - FpValues representing a set of double values.
|
||||
|
||||
class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest<F64>,
|
||||
public ::testing::WithParamInterface<FpValues> {
|
||||
private:
|
||||
int64 GetInputSize() override {
|
||||
FpValues values = GetParam();
|
||||
return values.GetTotalNumValues();
|
||||
}
|
||||
|
||||
void FillInput(std::array<Literal, 1>* input_literal) override {
|
||||
FpValues fp_values = GetParam();
|
||||
int64 input_size = (*input_literal)[0].element_count();
|
||||
LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", "
|
||||
<< input_size;
|
||||
absl::Span<double> input_arr = (*input_literal)[0].data<double>();
|
||||
|
||||
uint64 i = 0;
|
||||
for (auto bits : fp_values) {
|
||||
input_arr[i] = this->ConvertAndReplaceKnownIncorrectValueWith(bits, 1);
|
||||
++i;
|
||||
}
|
||||
CHECK_EQ(i, input_size);
|
||||
}
|
||||
};
|
||||
|
||||
#define UNARY_TEST_FLOAT_64(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
|
||||
UNARY_TEST_FLOAT_64(Log, { Run(Log, std::log); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Log1p, { Run(Log1p, std::log1p); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Exp, { Run(Exp, std::exp); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Expm1, { Run(Expm1, std::expm1); })
|
||||
|
||||
// TODO(b/138385863): Turn on the test for GPU after fixing the bug.
|
||||
UNARY_TEST_FLOAT_64(DISABLED_ON_GPU(PowOneHalf), {
|
||||
Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
|
||||
+[](double x) { return std::pow(x, 0.5); });
|
||||
})
|
||||
|
||||
UNARY_TEST_FLOAT_64(Rsqrt, {
|
||||
Run(
|
||||
Rsqrt, +[](double x) { return 1 / std::sqrt(x); });
|
||||
})
|
||||
|
||||
UNARY_TEST_FLOAT_64(Sqrt, { Run(Sqrt, std::sqrt); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Acosh, { Run(Acosh, std::acosh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Asinh, { Run(Asinh, std::asinh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Atanh, { Run(Atanh, std::atanh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Acos, { Run(Acos, std::acos); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Asin, { Run(Asin, std::asin); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Cosh, { Run(Cosh, std::cosh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Sinh, { Run(Sinh, std::sinh); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Tanh, {
|
||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||
if (platform_ == "CUDA") {
|
||||
error_spec_gen = +[](NativeT x) {
|
||||
return x <= static_cast<NativeT>(-20.0) || x >= static_cast<NativeT>(20.0)
|
||||
? ErrorSpec{0, 0}
|
||||
: GetDefaultSpecGenerator()(x);
|
||||
};
|
||||
}
|
||||
Run(Tanh, std::tanh, error_spec_gen);
|
||||
})
|
||||
|
||||
UNARY_TEST_FLOAT_64(Cos, { Run(Cos, std::cos); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Sin, { Run(Sin, std::sin); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Tan, { Run(Tan, std::tan); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Round, { Run(Round, std::round); })
|
||||
|
||||
UNARY_TEST_FLOAT_64(Erf, {
|
||||
Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
|
||||
})
|
||||
|
||||
UNARY_TEST_FLOAT_64(Erfc, {
|
||||
Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; });
|
||||
})
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
SpecialValues, ExhaustiveF64UnaryTest,
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest,
|
||||
::testing::Values(GetNormals<double>(1000)));
|
||||
|
||||
// Tests a total of 4000000000 inputs, with 16000000 inputs in each sub-test, to
|
||||
// keep the peak memory usage low.
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
LargeAndSmallMagnitudeNormalValues, ExhaustiveF64UnaryTest,
|
||||
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
|
||||
4000000000ull, 16000000)));
|
||||
|
||||
} // namespace exhaustive_op_test
|
||||
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user