[XLA] Add exhaustive test for F16 and BF16 binary operations.

This requires extending the ExhaustiveOpTestBase::Run methods to support binary
operations and moving CreateExhaustiveF32Ranges from exhaustive_unary_test to
exhaustive_op_test_utils.

PiperOrigin-RevId: 261786687
Author: Bixia Zheng
Date: 2019-08-05 15:47:17 -07:00
Committed by: TensorFlower Gardener
Parent: 0094a421fc
Commit: 66ba95a89c
5 changed files with 302 additions and 44 deletions
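For orientation, the extended Run interface keeps the existing unary call shape and adds a two-operand one; a test body distilled from the new file below looks roughly like this (host_add is the host-side reference function):

  // Unary form (pre-existing): one builder op, one reference function.
  Run(Log, std::log);
  // Binary form (added here): the builder op takes two operands.
  auto host_add = [](float x, float y) { return x + y; };
  Run(AddEmptyBroadcastDimension(Add), host_add);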

File: tensorflow/compiler/xla/tests/BUILD

@@ -789,6 +789,46 @@ xla_test(
],
)
xla_test(
name = "exhaustive_binary_test_f16",
srcs = ["exhaustive_binary_test.cc"],
backends = [
"gpu",
"cpu",
],
copts = ["-DBINARY_TEST_TARGET_F16"],
real_hardware_only = True, # Very slow on the interpreter.
shard_count = 48,
tags = [
"optonly",
# This is a big test that we skip for capacity reasons in OSS testing.
"no_oss",
],
deps = [
":exhaustive_op_test_utils",
],
)
xla_test(
name = "exhaustive_binary_test_bf16",
srcs = ["exhaustive_binary_test.cc"],
backends = [
"gpu",
"cpu",
],
copts = ["-DBINARY_TEST_TARGET_BF16"],
real_hardware_only = True, # Very slow on the interpreter.
shard_count = 48,
tags = [
"optonly",
# This is a big test that we skip for capacity reasons in OSS testing.
"no_oss",
],
deps = [
":exhaustive_op_test_utils",
],
)
xla_test(
name = "reduce_precision_test",
srcs = ["reduce_precision_test.cc"],

File: tensorflow/compiler/xla/tests/exhaustive_binary_test.cc (new file)

@@ -0,0 +1,178 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
#ifdef __FAST_MATH__
#error("Can't be compiled with fast math on");
#endif
namespace xla {
namespace {
template <PrimitiveType T>
using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;
// Exhaustive test for binary operations for 16-bit floating point types,
// including float16 and bfloat16.
//
// Test parameter is a pair of (begin, end) for range under test.
template <
PrimitiveType T,
typename std::enable_if<
std::is_same<typename primitive_util::PrimitiveTypeToNative<T>::type,
half>::value ||
std::is_same<typename primitive_util::PrimitiveTypeToNative<T>::type,
bfloat16>::value>::type* = nullptr>
class Exhaustive16BitBinaryTest
: public ExhaustiveBinaryTest<T>,
public ::testing::WithParamInterface<std::pair<int64, int64>> {
public:
int64 GetInputSize() override {
int64 begin, end;
std::tie(begin, end) = GetParam();
return end - begin;
}
// Given a (begin, end) range of 64-bit values, uses bits 0..15 and bits
// 16..31 of each value as src0 and src1 for the 16-bit binary operation
// being tested; sweeping all ranges generates the cartesian product of the
// two 16-bit input sets as the test inputs.
void FillInput(std::array<Literal, 2>* input_literals) override {
int64 input_size = GetInputSize();
CHECK_EQ(input_size, (*input_literals)[0].element_count());
CHECK_EQ(input_size, (*input_literals)[1].element_count());
int64 begin, end;
std::tie(begin, end) = GetParam();
VLOG(2) << "Checking range [" << begin << ", " << end << ")";
absl::Span<NativeT> input_arr_0 = (*input_literals)[0].data<NativeT>();
absl::Span<NativeT> input_arr_1 = (*input_literals)[1].data<NativeT>();
for (int64 i = 0; i < input_size; i++) {
uint32 input_val = i + begin;
// Convert the lower 16 bits to the NativeT and replace known incorrect
// input values with 0.
input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0);
input_arr_1[i] =
ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0);
}
}
protected:
using typename ExhaustiveBinaryTest<T>::NativeT;
using ExhaustiveBinaryTest<T>::ConvertAndReplaceKnownIncorrectValueWith;
};
using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest<F16>;
using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest<BF16>;
// Returns a wrapper of the given build method, which builds an HLO operation
// with empty broadcast dimensions.
inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
std::function<XlaOp(XlaOp, XlaOp, absl::Span<const int64>)> build_method) {
// Capture build_method by value: a reference capture would dangle once this
// factory returns.
return [build_method](XlaOp src0, XlaOp src1) -> XlaOp {
return build_method(src0, src1, {});
};
}
#define XLA_TEST_16BIT(test_name, ...) \
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \
__VA_ARGS__ \
XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \
__VA_ARGS__
XLA_TEST_16BIT(Add, {
auto host_add = [](float x, float y) { return x + y; };
Run(AddEmptyBroadcastDimension(Add), host_add);
})
XLA_TEST_16BIT(Sub, {
auto host_sub = [](float x, float y) { return x - y; };
Run(AddEmptyBroadcastDimension(Sub), host_sub);
})
// TODO(bixia): Mul fails with bfloat16 on CPU.
XLA_TEST_16BIT(DISABLED_ON_CPU(Mul), {
auto host_mul = [](float x, float y) { return x * y; };
Run(AddEmptyBroadcastDimension(Mul), host_mul);
})
// TODO(bixia): Div fails with bfloat16 on CPU.
XLA_TEST_16BIT(DISABLED_ON_CPU(Div), {
auto host_div = [](float x, float y) { return x / y; };
Run(AddEmptyBroadcastDimension(Div), host_div);
})
template <typename T, typename std::enable_if<
std::is_same<T, float>::value ||
std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMax(T x, T y) {
// We need to propagate NAN here because std::max may not propagate NAN.
if (std::fpclassify(x) == FP_NAN) {
return x;
}
if (std::fpclassify(y) == FP_NAN) {
return y;
}
return std::max<T>(x, y);
}
template <typename T, typename std::enable_if<
std::is_same<T, float>::value ||
std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMin(T x, T y) {
// We need to propagate NAN here because std::min may not propagate NAN.
if (std::fpclassify(x) == FP_NAN) {
return x;
}
if (std::fpclassify(y) == FP_NAN) {
return y;
}
return std::min<T>(x, y);
}
XLA_TEST_16BIT(Max,
{ Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>); })
XLA_TEST_16BIT(Min,
{ Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>); })
// TODO(bixia): Pow fails with bfloat16 on CPU.
XLA_TEST_16BIT(DISABLED_ON_CPU(Pow),
{ Run(AddEmptyBroadcastDimension(Pow), std::powf); })
// TODO(bixia): Atan2 fails with bfloat16 on CPU.
XLA_TEST_16BIT(DISABLED_ON_CPU(Atan2),
{ Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); })
#if defined(BINARY_TEST_TARGET_F16)
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest,
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
#endif
#endif
#if defined(BINARY_TEST_TARGET_BF16)
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest,
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
#endif
#endif
} // namespace
} // namespace xla
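To make the FillInput packing above concrete, here is a minimal standalone sketch; SplitIndex is a hypothetical helper, and the real test converts each half to NativeT and zeroes known-incorrect values rather than keeping raw uint16 bits:

  #include <cstdint>
  #include <cstdio>
  #include <utility>

  // Split one 32-bit index into the two 16-bit operands, as FillInput does.
  std::pair<uint16_t, uint16_t> SplitIndex(uint32_t input_val) {
    uint16_t src0 = input_val & 0xFFFF;          // bits 0..15
    uint16_t src1 = (input_val >> 16) & 0xFFFF;  // bits 16..31
    return {src0, src1};
  }

  int main() {
    // Sweeping the index over [0, 2^32) enumerates every (src0, src1) pair
    // exactly once: the cartesian product of the two 16-bit value sets.
    auto [src0, src1] = SplitIndex(0x00010002u);
    std::printf("src0=0x%04x src1=0x%04x\n", src0, src1);  // 0x0002, 0x0001
  }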

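The ReferenceMax/ReferenceMin wrappers above are needed because std::max and std::min are defined via operator<, and every comparison involving NaN is false, so whether a NaN survives depends only on argument order. A quick demonstration:

  #include <algorithm>
  #include <cmath>
  #include <cstdio>

  int main() {
    float nan = std::nanf("");
    // std::max(a, b) returns a unless a < b, so a NaN in the second slot
    // is silently dropped while a NaN in the first slot is returned.
    std::printf("%f\n", std::max(1.0f, nan));  // prints 1.000000
    std::printf("%f\n", std::max(nan, 1.0f));  // prints nan
  }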
File: tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc

@@ -171,6 +171,7 @@ inline typename ExhaustiveOpTestBase<T, N>::ErrorSpec DefaultSpecGenerator(
typename ExhaustiveOpTestBase<T, N>::NativeT) {
LOG(FATAL) << "Unhandled Type";
}
template <PrimitiveType T, size_t N>
inline typename ExhaustiveOpTestBase<T, N>::ErrorSpec DefaultSpecGenerator(
typename ExhaustiveOpTestBase<T, N>::NativeT,
@@ -213,6 +214,18 @@ inline ExhaustiveOpTestBase<BF16, 1>::ErrorSpec DefaultSpecGenerator<BF16, 1>(
bfloat16) {
return ExhaustiveOpTestBase<BF16, 1>::ErrorSpec{0.002, 0.02};
}
template <>
inline ExhaustiveOpTestBase<F16, 2>::ErrorSpec DefaultSpecGenerator<F16, 2>(
Eigen::half, Eigen::half) {
return ExhaustiveOpTestBase<F16, 2>::ErrorSpec{0.001, 0.001};
}
template <>
inline ExhaustiveOpTestBase<BF16, 2>::ErrorSpec DefaultSpecGenerator<BF16, 2>(
bfloat16, bfloat16) {
return ExhaustiveOpTestBase<BF16, 2>::ErrorSpec{0.002, 0.02};
}
} // namespace
/*static*/
@@ -229,4 +242,7 @@ template class ExhaustiveOpTestBase<F32, 1>;
template class ExhaustiveOpTestBase<F16, 1>;
template class ExhaustiveOpTestBase<BF16, 1>;
template class ExhaustiveOpTestBase<F16, 2>;
template class ExhaustiveOpTestBase<BF16, 2>;
} // namespace xla

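For reference, the two numbers in each ErrorSpec above are an absolute and a relative tolerance, in that order. A sketch of how such a pair gates a comparison, assuming the simplest accept-if-either-bound-holds rule (the real ExpectNear additionally special-cases NaNs, infinities, and subnormals):

  #include <cmath>

  bool WithinSpec(float expected, float actual, float abs_err, float rel_err) {
    float diff = std::fabs(expected - actual);
    return diff <= abs_err || diff <= rel_err * std::fabs(expected);
  }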
File: tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h

@@ -85,41 +85,61 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
};
// Native types that correspond to the primitive types above.
typedef typename primitive_util::PrimitiveTypeToNative<T>::type NativeT;
typedef typename primitive_util::PrimitiveTypeToNative<RefT::value>::type
NativeRefT;
typedef
typename primitive_util::PrimitiveTypeToNative<ComponentT::value>::type
ComponentNativeT;
typedef
typename primitive_util::PrimitiveTypeToNative<ComponentRefT::value>::type
ComponentNativeRefT;
typedef typename primitive_util::PrimitiveTypeToNative<
ComponentIntegralT::value>::type ComponentIntegralNativeT;
using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
using NativeRefT =
typename primitive_util::PrimitiveTypeToNative<RefT::value>::type;
using ComponentNativeT =
typename primitive_util::PrimitiveTypeToNative<ComponentT::value>::type;
using ComponentNativeRefT = typename primitive_util::PrimitiveTypeToNative<
ComponentRefT::value>::type;
using ComponentIntegralNativeT =
typename primitive_util::PrimitiveTypeToNative<
ComponentIntegralT::value>::type;
typedef std::array<Literal, N> InputLiterals;
using InputLiterals = std::array<Literal, N>;
private:
// N spans corresponding to the list of literal data values.
typedef std::array<absl::Span<const NativeT>, N> NativeInputsList;
using NativeInputsList = std::array<absl::Span<const NativeT>, N>;
// N data items representing a single input to some XLA function.
typedef std::array<NativeT, N> NativeInputs;
// N data items representing a single input to an XLA function.
using NativeInputs = std::array<NativeT, N>;
// N data items representing a single input to some interpreter backend
// N data items representing a single input to an interpreter backend
// function.
typedef std::array<NativeRefT, N> NativeRefInputs;
using NativeRefInputs = std::array<NativeRefT, N>;
// N data items representing a single input to an XLA function.
using XlaInputs = std::array<XlaOp, N>;
// Representations of the reference function passed in by the user.
template <size_t K>
struct EvaluateOpWrapper {};
template <>
struct EvaluateOpWrapper<1> {
typedef NativeRefT (*type)(NativeRefT);
using type = NativeRefT (*)(NativeRefT);
};
template <>
struct EvaluateOpWrapper<2> {
typedef NativeRefT (*type)(NativeRefT, NativeRefT);
using type = NativeRefT (*)(NativeRefT, NativeRefT);
};
// Representations of the enqueue function passed in by the user.
template <size_t K>
struct EnqueueOpWrapper {};
template <>
struct EnqueueOpWrapper<1> {
using type = std::function<XlaOp(XlaOp)>;
static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
return ty(inputs[0]);
}
};
template <>
struct EnqueueOpWrapper<2> {
using type = std::function<XlaOp(XlaOp, XlaOp)>;
static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
return ty(inputs[0], inputs[1]);
}
};
// Representations of the ErrorSpecGen function passed in by the user.
@@ -127,16 +147,17 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
struct ErrorSpecGenWrapper {};
template <>
struct ErrorSpecGenWrapper<1> {
typedef ErrorSpec (*type)(NativeT);
using type = ErrorSpec (*)(NativeT);
};
template <>
struct ErrorSpecGenWrapper<2> {
typedef ErrorSpec (*type)(NativeT, NativeT);
using type = ErrorSpec (*)(NativeT, NativeT);
};
public:
using ErrorSpecGen = typename ErrorSpecGenWrapper<N>::type;
using EvaluateOp = typename EvaluateOpWrapper<N>::type;
using EnqueueOp = typename EnqueueOpWrapper<N>::type;
explicit ExhaustiveOpTestBase()
: ty_(T), platform_(client_->platform()->Name()) {
@@ -147,7 +168,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
mutable_debug_options()->clear_xla_disable_hlo_passes();
}
void Run(std::function<XlaOp(XlaOp)> enqueue_op, EvaluateOp evaluate_op) {
void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) {
Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator());
}
@@ -158,16 +179,18 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
// We use a function pointer for evaluate_op for performance because it is
// called each time an output element is compared inside a loop in routine
// ExpectNear.
void Run(std::function<XlaOp(XlaOp)> enqueue_op, EvaluateOp evaluate_op,
void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op,
ErrorSpecGen error_spec_gen) {
InputLiterals input_literals = CreateInputLiterals();
FillInput(&input_literals);
XlaBuilder builder(TestName());
XlaInputs xla_inputs;
for (int i = 0; i < N; ++i) {
enqueue_op(Parameter(&builder, i, input_literals[i].shape(), "input"));
xla_inputs[i] =
Parameter(&builder, i, input_literals[i].shape(), "input");
}
EnqueueOpWrapper<N>::BuildFromInputs(xla_inputs, enqueue_op);
TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
@@ -350,7 +373,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
std::vector<std::complex<ComponentNativeRefT>>
GetTestValuesWithSubnormalSubstitutions(
std::complex<ComponentNativeRefT> value) {
typedef std::complex<ComponentNativeRefT> complex;
using complex = std::complex<ComponentNativeRefT>;
auto real_values = GetTestValuesWithSubnormalSubstitutions(value.real());
auto imag_values = GetTestValuesWithSubnormalSubstitutions(value.imag());
@@ -738,8 +761,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
bool relaxed_denormal_signs_ = platform_ != "CUDA";
private:
typedef NativeRefT (*EvaluateOpInternal)(NativeRefInputs);
typedef ErrorSpec (*ErrorSpecGenInternal)(NativeInputs);
using EvaluateOpInternal = NativeRefT (*)(NativeRefInputs);
using ErrorSpecGenInternal = ErrorSpec (*)(NativeInputs);
template <typename Type, typename FuncPtr>
ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 1>& in) {
@@ -1026,10 +1049,10 @@ class FpValues {
std::array<int, kTotalBitChunks + 1> offsets_;
};
template <typename T>
template <typename T, typename std::enable_if<
std::is_same<T, float>::value ||
std::is_same<T, double>::value>::type* = nullptr>
int GetMantissaTotalBits() {
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
"Only supports float and double.");
return std::numeric_limits<T>::digits - 1;
}
@@ -1053,10 +1076,10 @@ uint64 GetAllOneExponent() {
return (1ull << GetExponentTotalBits<T>()) - 1ull;
}
template <typename T>
template <typename T, typename std::enable_if<
std::is_same<T, float>::value ||
std::is_same<T, double>::value>::type* = nullptr>
FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) {
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
"Only supports float and double.");
int total_bits = GetFpTotalBits<T>();
return FpValues({mantissa, exponent, sign},
{0, GetMantissaTotalBits<T>(), total_bits - 1, total_bits});
@@ -1167,5 +1190,16 @@ std::vector<FpValues> CreateFpValuesForBoundaryTest() {
GetNans<T>(1000)};
}
inline std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
// We break up the 2^32-element space into small'ish chunks to keep peak
// memory usage low.
std::vector<std::pair<int64, int64>> result;
const int64 step = 1 << 25;
for (int64 i = 0; i < (int64{1} << 32); i += step) {
result.push_back({i, i + step});
}
return result;
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
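As a sanity check on the chunking arithmetic in CreateExhaustiveF32Ranges just above: 2^32 inputs in steps of 2^25 yield 128 half-open (begin, end) ranges of 33,554,432 values each, which the BUILD rules then spread over shard_count = 48 shards:

  #include <cstdint>

  // 2^25-element chunks tile the 32-bit space exactly, in 128 ranges.
  static_assert((uint64_t{1} << 32) % (uint64_t{1} << 25) == 0, "exact tiling");
  static_assert((uint64_t{1} << 32) / (uint64_t{1} << 25) == 128, "128 ranges");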

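Stripped of the XLA types, the EnqueueOpWrapper/EvaluateOpWrapper pattern earlier in this header reduces to the following sketch (illustrative names): the operand count N picks a specialization at compile time, so one driver serves both arities without branching:

  #include <array>
  #include <cstddef>
  #include <functional>

  template <typename T, size_t N> struct OpWrapper;

  template <typename T> struct OpWrapper<T, 1> {
    using type = std::function<T(T)>;
    static T Apply(const std::array<T, 1>& in, const type& op) {
      return op(in[0]);
    }
  };

  template <typename T> struct OpWrapper<T, 2> {
    using type = std::function<T(T, T)>;
    static T Apply(const std::array<T, 2>& in, const type& op) {
      return op(in[0], in[1]);
    }
  };

  // Shared driver, analogous to Run: builds the call from N packed inputs.
  template <typename T, size_t N>
  T ApplyOp(const std::array<T, N>& in,
            const typename OpWrapper<T, N>::type& op) {
    return OpWrapper<T, N>::Apply(in, op);
  }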
File: tensorflow/compiler/xla/tests/exhaustive_unary_test.cc

@@ -538,16 +538,6 @@ XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, {
XLA_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); })
#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER)
std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
// We break up the 2^32-element space into small'ish chunks to keep peak
// memory usage low.
std::vector<std::pair<int64, int64>> result;
const int64 step = 1 << 25;
for (int64 i = 0; i < (1l << 32); i += step) {
result.push_back({i, i + step});
}
return result;
}
INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest,
::testing::ValuesIn(CreateExhaustiveF32Ranges()));