[XLA] Add exhaustive test for F16 and BF16 binary operations.
This requires extending the ExhaustiveOpTestBase::run methods to support binary operation and moving CreateExhaustiveF32Ranges from exhaustive_unary_test to exhaustive_op_test_utils. PiperOrigin-RevId: 261786687
This commit is contained in:
parent
0094a421fc
commit
66ba95a89c
@ -789,6 +789,46 @@ xla_test(
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "exhaustive_binary_test_f16",
|
||||
srcs = ["exhaustive_binary_test.cc"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
copts = ["-DBINARY_TEST_TARGET_F16"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
tags = [
|
||||
"optonly",
|
||||
# This is a big test that we skip for capacity reasons in OSS testing.
|
||||
"no_oss",
|
||||
],
|
||||
deps = [
|
||||
":exhaustive_op_test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "exhaustive_binary_test_bf16",
|
||||
srcs = ["exhaustive_binary_test.cc"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
copts = ["-DBINARY_TEST_TARGET_BF16"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
tags = [
|
||||
"optonly",
|
||||
# This is a big test that we skip for capacity reasons in OSS testing.
|
||||
"no_oss",
|
||||
],
|
||||
deps = [
|
||||
":exhaustive_op_test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "reduce_precision_test",
|
||||
srcs = ["reduce_precision_test.cc"],
|
||||
|
178
tensorflow/compiler/xla/tests/exhaustive_binary_test.cc
Normal file
178
tensorflow/compiler/xla/tests/exhaustive_binary_test.cc
Normal file
@ -0,0 +1,178 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
|
||||
|
||||
#ifdef __FAST_MATH__
|
||||
#error("Can't be compiled with fast math on");
|
||||
#endif
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
template <PrimitiveType T>
|
||||
using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;
|
||||
|
||||
// Exhaustive test for binary operations for 16 bit floating point types,
|
||||
// including float16 and bfloat.
|
||||
//
|
||||
// Test parameter is a pair of (begin, end) for range under test.
|
||||
template <
|
||||
PrimitiveType T,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename primitive_util::PrimitiveTypeToNative<T>::type,
|
||||
half>::value ||
|
||||
std::is_same<typename primitive_util::PrimitiveTypeToNative<T>::type,
|
||||
bfloat16>::value>::type* = nullptr>
|
||||
class Exhaustive16BitBinaryTest
|
||||
: public ExhaustiveBinaryTest<T>,
|
||||
public ::testing::WithParamInterface<std::pair<int64, int64>> {
|
||||
public:
|
||||
int64 GetInputSize() override {
|
||||
int64 begin, end;
|
||||
std::tie(begin, end) = GetParam();
|
||||
return end - begin;
|
||||
}
|
||||
|
||||
// Given a range of uint64 representation, uses bits 0..15 and bits 16..31 for
|
||||
// the values of src0 and src1 for a 16 bit binary operation being tested,
|
||||
// and generates the cartesian product of the two sets as the two inputs for
|
||||
// the test.
|
||||
void FillInput(std::array<Literal, 2>* input_literals) override {
|
||||
int64 input_size = GetInputSize();
|
||||
CHECK_EQ(input_size, (*input_literals)[0].element_count());
|
||||
CHECK_EQ(input_size, (*input_literals)[1].element_count());
|
||||
|
||||
int64 begin, end;
|
||||
std::tie(begin, end) = GetParam();
|
||||
VLOG(2) << "Checking range [" << begin << ", " << end << "]";
|
||||
|
||||
absl::Span<NativeT> input_arr_0 = (*input_literals)[0].data<NativeT>();
|
||||
absl::Span<NativeT> input_arr_1 = (*input_literals)[1].data<NativeT>();
|
||||
for (int64 i = 0; i < input_size; i++) {
|
||||
uint32 input_val = i + begin;
|
||||
// Convert the lower 16 bits to the NativeT and replaced known incorrect
|
||||
// input values with 0.
|
||||
input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0);
|
||||
input_arr_1[i] =
|
||||
ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using typename ExhaustiveBinaryTest<T>::NativeT;
|
||||
using ExhaustiveBinaryTest<T>::ConvertAndReplaceKnownIncorrectValueWith;
|
||||
};
|
||||
|
||||
using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest<F16>;
|
||||
using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest<BF16>;
|
||||
|
||||
// Returns a wrapper of the given build method, which build an HLO operation
|
||||
// with an empty broadcast dimension.
|
||||
inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
|
||||
std::function<XlaOp(XlaOp, XlaOp, absl::Span<const int64>)> build_method) {
|
||||
return [&](XlaOp src0, XlaOp src1) -> XlaOp {
|
||||
return build_method(src0, src1, {});
|
||||
};
|
||||
}
|
||||
|
||||
#define XLA_TEST_16BIT(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \
|
||||
__VA_ARGS__ \
|
||||
XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
|
||||
XLA_TEST_16BIT(Add, {
|
||||
auto host_add = [](float x, float y) { return x + y; };
|
||||
Run(AddEmptyBroadcastDimension(Add), host_add);
|
||||
})
|
||||
|
||||
XLA_TEST_16BIT(Sub, {
|
||||
auto host_sub = [](float x, float y) { return x - y; };
|
||||
Run(AddEmptyBroadcastDimension(Sub), host_sub);
|
||||
})
|
||||
|
||||
// TODO(bixia): Mul fails with bfloat16 on CPU.
|
||||
XLA_TEST_16BIT(DISABLED_ON_CPU(Mul), {
|
||||
auto host_mul = [](float x, float y) { return x * y; };
|
||||
Run(AddEmptyBroadcastDimension(Mul), host_mul);
|
||||
})
|
||||
|
||||
// TODO(bixia): Div fails with bfloat16 on CPU.
|
||||
XLA_TEST_16BIT(DISABLED_ON_CPU(Div), {
|
||||
auto host_div = [](float x, float y) { return x / y; };
|
||||
Run(AddEmptyBroadcastDimension(Div), host_div);
|
||||
})
|
||||
|
||||
template <typename T, typename std::enable_if<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value>::type* = nullptr>
|
||||
T ReferenceMax(T x, T y) {
|
||||
// We need to propagate NAN here becasue std::max may not propagate NAN.
|
||||
if (std::fpclassify(x) == FP_NAN) {
|
||||
return x;
|
||||
}
|
||||
if (std::fpclassify(y) == FP_NAN) {
|
||||
return y;
|
||||
}
|
||||
|
||||
return std::max<T>(x, y);
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value>::type* = nullptr>
|
||||
T ReferenceMin(T x, T y) {
|
||||
// We need to propagate NAN here becasue std::max may not propagate NAN.
|
||||
if (std::fpclassify(x) == FP_NAN) {
|
||||
return x;
|
||||
}
|
||||
if (std::fpclassify(y) == FP_NAN) {
|
||||
return y;
|
||||
}
|
||||
|
||||
return std::min<T>(x, y);
|
||||
}
|
||||
|
||||
XLA_TEST_16BIT(Max,
|
||||
{ Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>); })
|
||||
|
||||
XLA_TEST_16BIT(Min,
|
||||
{ Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>); })
|
||||
|
||||
// TODO(bixia): Pow fails with bfloat16 on CPU.
|
||||
XLA_TEST_16BIT(DISABLED_ON_CPU(Pow),
|
||||
{ Run(AddEmptyBroadcastDimension(Pow), std::powf); })
|
||||
|
||||
// TODO(bixia): Atan2 fails with bfloat16 on CPU.
|
||||
XLA_TEST_16BIT(DISABLED_ON_CPU(Atan2),
|
||||
{ Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); })
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_F16)
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||
INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest,
|
||||
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_BF16)
|
||||
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
|
||||
INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest,
|
||||
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -171,6 +171,7 @@ inline typename ExhaustiveOpTestBase<T, N>::ErrorSpec DefaultSpecGenerator(
|
||||
typename ExhaustiveOpTestBase<T, N>::NativeT) {
|
||||
LOG(FATAL) << "Unhandled Type";
|
||||
}
|
||||
|
||||
template <PrimitiveType T, size_t N>
|
||||
inline typename ExhaustiveOpTestBase<T, N>::ErrorSpec DefaultSpecGenerator(
|
||||
typename ExhaustiveOpTestBase<T, N>::NativeT,
|
||||
@ -213,6 +214,18 @@ inline ExhaustiveOpTestBase<BF16, 1>::ErrorSpec DefaultSpecGenerator<BF16, 1>(
|
||||
bfloat16) {
|
||||
return ExhaustiveOpTestBase<BF16, 1>::ErrorSpec{0.002, 0.02};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline ExhaustiveOpTestBase<F16, 2>::ErrorSpec DefaultSpecGenerator<F16, 2>(
|
||||
Eigen::half, Eigen::half) {
|
||||
return ExhaustiveOpTestBase<F16, 2>::ErrorSpec{0.001, 0.001};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline ExhaustiveOpTestBase<BF16, 2>::ErrorSpec DefaultSpecGenerator<BF16, 2>(
|
||||
bfloat16, bfloat16) {
|
||||
return ExhaustiveOpTestBase<BF16, 2>::ErrorSpec{0.002, 0.02};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/*static*/
|
||||
@ -229,4 +242,7 @@ template class ExhaustiveOpTestBase<F32, 1>;
|
||||
template class ExhaustiveOpTestBase<F16, 1>;
|
||||
template class ExhaustiveOpTestBase<BF16, 1>;
|
||||
|
||||
template class ExhaustiveOpTestBase<F16, 2>;
|
||||
template class ExhaustiveOpTestBase<BF16, 2>;
|
||||
|
||||
} // namespace xla
|
||||
|
@ -85,41 +85,61 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
};
|
||||
|
||||
// Native types that correspond to the primtive types above.
|
||||
typedef typename primitive_util::PrimitiveTypeToNative<T>::type NativeT;
|
||||
typedef typename primitive_util::PrimitiveTypeToNative<RefT::value>::type
|
||||
NativeRefT;
|
||||
typedef
|
||||
typename primitive_util::PrimitiveTypeToNative<ComponentT::value>::type
|
||||
ComponentNativeT;
|
||||
typedef
|
||||
typename primitive_util::PrimitiveTypeToNative<ComponentRefT::value>::type
|
||||
ComponentNativeRefT;
|
||||
typedef typename primitive_util::PrimitiveTypeToNative<
|
||||
ComponentIntegralT::value>::type ComponentIntegralNativeT;
|
||||
using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
|
||||
using NativeRefT =
|
||||
typename primitive_util::PrimitiveTypeToNative<RefT::value>::type;
|
||||
using ComponentNativeT =
|
||||
typename primitive_util::PrimitiveTypeToNative<ComponentT::value>::type;
|
||||
using ComponentNativeRefT = typename primitive_util::PrimitiveTypeToNative<
|
||||
ComponentRefT::value>::type;
|
||||
using ComponentIntegralNativeT =
|
||||
typename primitive_util::PrimitiveTypeToNative<
|
||||
ComponentIntegralT::value>::type;
|
||||
|
||||
typedef std::array<Literal, N> InputLiterals;
|
||||
using InputLiterals = std::array<Literal, N>;
|
||||
|
||||
private:
|
||||
// N spans corresponding to the list of literal data values.
|
||||
typedef std::array<absl::Span<const NativeT>, N> NativeInputsList;
|
||||
using NativeInputsList = std::array<absl::Span<const NativeT>, N>;
|
||||
|
||||
// N data items representing a single input to some XLA function.
|
||||
typedef std::array<NativeT, N> NativeInputs;
|
||||
// N data items representing a single input to an XLA function.
|
||||
using NativeInputs = std::array<NativeT, N>;
|
||||
|
||||
// N data items representing a single input to some interpreter backend
|
||||
// N data items representing a single input to an interpreter backend
|
||||
// function.
|
||||
typedef std::array<NativeRefT, N> NativeRefInputs;
|
||||
using NativeRefInputs = std::array<NativeRefT, N>;
|
||||
|
||||
// N data items representing a single input to an XLA function.
|
||||
using XlaInputs = std::array<XlaOp, N>;
|
||||
|
||||
// Representations of the reference function passed in by the user.
|
||||
template <size_t K>
|
||||
struct EvaluateOpWrapper {};
|
||||
template <>
|
||||
struct EvaluateOpWrapper<1> {
|
||||
typedef NativeRefT (*type)(NativeRefT);
|
||||
using type = NativeRefT (*)(NativeRefT);
|
||||
};
|
||||
template <>
|
||||
struct EvaluateOpWrapper<2> {
|
||||
typedef NativeRefT (*type)(NativeRefT, NativeRefT);
|
||||
using type = NativeRefT (*)(NativeRefT, NativeRefT);
|
||||
};
|
||||
|
||||
// Representations of the reference function passed in by the user.
|
||||
template <size_t K>
|
||||
struct EnqueueOpWrapper {};
|
||||
template <>
|
||||
struct EnqueueOpWrapper<1> {
|
||||
using type = std::function<XlaOp(XlaOp)>;
|
||||
static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
|
||||
return ty(inputs[0]);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct EnqueueOpWrapper<2> {
|
||||
using type = std::function<XlaOp(XlaOp, XlaOp)>;
|
||||
static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
|
||||
return ty(inputs[0], inputs[1]);
|
||||
}
|
||||
};
|
||||
|
||||
// Representations of the ErrorSpecGen function passed in by the user.
|
||||
@ -127,16 +147,17 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
struct ErrorSpecGenWrapper {};
|
||||
template <>
|
||||
struct ErrorSpecGenWrapper<1> {
|
||||
typedef ErrorSpec (*type)(NativeT);
|
||||
using type = ErrorSpec (*)(NativeT);
|
||||
};
|
||||
template <>
|
||||
struct ErrorSpecGenWrapper<2> {
|
||||
typedef ErrorSpec (*type)(NativeT, NativeT);
|
||||
using type = ErrorSpec (*)(NativeT, NativeT);
|
||||
};
|
||||
|
||||
public:
|
||||
using ErrorSpecGen = typename ErrorSpecGenWrapper<N>::type;
|
||||
using EvaluateOp = typename EvaluateOpWrapper<N>::type;
|
||||
using EnqueueOp = typename EnqueueOpWrapper<N>::type;
|
||||
|
||||
explicit ExhaustiveOpTestBase()
|
||||
: ty_(T), platform_(client_->platform()->Name()) {
|
||||
@ -147,7 +168,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
mutable_debug_options()->clear_xla_disable_hlo_passes();
|
||||
}
|
||||
|
||||
void Run(std::function<XlaOp(XlaOp)> enqueue_op, EvaluateOp evaluate_op) {
|
||||
void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) {
|
||||
Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator());
|
||||
}
|
||||
|
||||
@ -158,16 +179,18 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
// We use a function pointer for evaluate_op for performance because it is
|
||||
// called each time an output element is compared inside a loop in routine
|
||||
// ExpectNear.
|
||||
void Run(std::function<XlaOp(XlaOp)> enqueue_op, EvaluateOp evaluate_op,
|
||||
void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op,
|
||||
ErrorSpecGen error_spec_gen) {
|
||||
InputLiterals input_literals = CreateInputLiterals();
|
||||
FillInput(&input_literals);
|
||||
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
XlaInputs xla_inputs;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
enqueue_op(Parameter(&builder, i, input_literals[i].shape(), "input"));
|
||||
xla_inputs[i] =
|
||||
Parameter(&builder, i, input_literals[i].shape(), "input");
|
||||
}
|
||||
EnqueueOpWrapper<N>::BuildFromInputs(xla_inputs, enqueue_op);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
|
||||
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
|
||||
@ -350,7 +373,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
std::vector<std::complex<ComponentNativeRefT>>
|
||||
GetTestValuesWithSubnormalSubstitutions(
|
||||
std::complex<ComponentNativeRefT> value) {
|
||||
typedef std::complex<ComponentNativeRefT> complex;
|
||||
using complex = std::complex<ComponentNativeRefT>;
|
||||
|
||||
auto real_values = GetTestValuesWithSubnormalSubstitutions(value.real());
|
||||
auto imag_values = GetTestValuesWithSubnormalSubstitutions(value.imag());
|
||||
@ -738,8 +761,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
bool relaxed_denormal_signs_ = platform_ != "CUDA";
|
||||
|
||||
private:
|
||||
typedef NativeRefT (*EvaluateOpInternal)(NativeRefInputs);
|
||||
typedef ErrorSpec (*ErrorSpecGenInternal)(NativeInputs);
|
||||
using EvaluateOpInternal = NativeRefT (*)(NativeRefInputs);
|
||||
using ErrorSpecGenInternal = ErrorSpec (*)(NativeInputs);
|
||||
|
||||
template <typename Type, typename FuncPtr>
|
||||
ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 1>& in) {
|
||||
@ -1026,10 +1049,10 @@ class FpValues {
|
||||
std::array<int, kTotalBitChunks + 1> offsets_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename std::enable_if<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value>::type* = nullptr>
|
||||
int GetMantissaTotalBits() {
|
||||
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
|
||||
"Only supports float and double.");
|
||||
return std::numeric_limits<T>::digits - 1;
|
||||
}
|
||||
|
||||
@ -1053,10 +1076,10 @@ uint64 GetAllOneExponent() {
|
||||
return (1ull << GetExponentTotalBits<T>()) - 1ull;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename std::enable_if<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value>::type* = nullptr>
|
||||
FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) {
|
||||
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
|
||||
"Only supports float and double.");
|
||||
int total_bits = GetFpTotalBits<T>();
|
||||
return FpValues({mantissa, exponent, sign},
|
||||
{0, GetMantissaTotalBits<T>(), total_bits - 1, total_bits});
|
||||
@ -1167,5 +1190,16 @@ std::vector<FpValues> CreateFpValuesForBoundaryTest() {
|
||||
GetNans<T>(1000)};
|
||||
}
|
||||
|
||||
inline std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
|
||||
// We break up the 2^32-element space into small'ish chunks to keep peak
|
||||
// memory usage low.
|
||||
std::vector<std::pair<int64, int64>> result;
|
||||
const int64 step = 1 << 25;
|
||||
for (int64 i = 0; i < (1l << 32); i += step) {
|
||||
result.push_back({i, i + step});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
|
||||
|
@ -538,16 +538,6 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, {
|
||||
XLA_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); })
|
||||
|
||||
#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER)
|
||||
std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
|
||||
// We break up the 2^32-element space into small'ish chunks to keep peak
|
||||
// memory usage low.
|
||||
std::vector<std::pair<int64, int64>> result;
|
||||
const int64 step = 1 << 25;
|
||||
for (int64 i = 0; i < (1l << 32); i += step) {
|
||||
result.push_back({i, i + step});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest,
|
||||
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
|
||||
|
Loading…
Reference in New Issue
Block a user