[XLA] Support the reduce-precision HLO for all floating-point types.

PiperOrigin-RevId: 258823530
Justin Lebar 2019-07-18 12:35:58 -07:00 committed by TensorFlower Gardener
parent c68ed58053
commit 4bd8619abf
3 changed files with 541 additions and 200 deletions
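For orientation, here is a minimal standalone sketch (editor's illustration, not XLA API) of the reduce-precision semantics this change generalizes beyond F32: round the mantissa to the destination width with ties-to-even, clamp values whose exponent no longer fits in the destination exponent width to signed zero or infinity (flushing denormals to zero), and preserve NaNs unless the mantissa is reduced to zero bits. It mirrors the evaluator logic in the diff below; the function name and the precondition exponent_bits >= 1 are illustrative.

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Scalar reference for reduce-precision on an f32 input (sketch only).
float ReducePrecisionF32(float x, int exponent_bits, int mantissa_bits) {
  uint32_t v;
  std::memcpy(&v, &x, sizeof(v));
  if (mantissa_bits < 23) {
    // Round to nearest, ties to even, by adding a data-dependent bias and
    // masking off the truncated mantissa bits.
    const uint32_t last_bit = 1u << (23 - mantissa_bits);
    const uint32_t bias =
        ((last_bit >> 1) - 1) + ((v & last_bit) >> (23 - mantissa_bits));
    v = (v + bias) & ~(last_bit - 1);
  }
  if (exponent_bits < 8) {
    // Clamp exponents the narrower format cannot represent.
    const uint32_t sign = v & (1u << 31);
    const uint32_t exp = v & (0xffu << 23);
    const uint32_t reduced_bias = (1u << (exponent_bits - 1)) - 1;
    if (exp > ((127u + reduced_bias) << 23)) v = sign | (0xffu << 23);  // +/-inf
    if (exp <= ((127u - reduced_bias) << 23)) v = sign;                 // +/-0
  }
  float result;
  std::memcpy(&result, &v, sizeof(result));
  if (std::isnan(x)) {
    result = mantissa_bits > 0 ? x : std::numeric_limits<float>::infinity();
  }
  return result;
}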


@@ -63,28 +63,44 @@ int64 GlobalRandomValue() {
return rng();
}
llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
int64 mantissa_bits,
StatusOr<llvm::Value*> EmitReducePrecisionIR(PrimitiveType src_ty,
llvm::Value* x,
int64 dest_exponent_bits,
int64 dest_mantissa_bits,
llvm::IRBuilder<>* b) {
using llvm::APInt;
if (!primitive_util::IsFloatingPointType(src_ty)) {
return Unimplemented(
"ReducePrecision cannot accept non-floating-point type %s.",
PrimitiveType_Name(src_ty));
}
// Integer and float types for casting and constant generation.
llvm::Type* float_type = x->getType();
llvm::IntegerType* int_type = b->getInt32Ty();
int64 nbits = float_type->getPrimitiveSizeInBits();
llvm::IntegerType* int_type = b->getIntNTy(nbits);
// SignificandWidth includes the implicit extra bit.
int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
int src_exponent_bits = nbits - 1 - src_mantissa_bits;
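// E.g. for F32, SignificandWidth is 24, so src_mantissa_bits is 23 and
// src_exponent_bits is 32 - 1 - 23 = 8.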
// Cast the input value to an integer for bitwise manipulation.
llvm::Value* x_as_int = b->CreateBitCast(x, int_type);
if (mantissa_bits < 23) {
if (dest_mantissa_bits < src_mantissa_bits) {
// Last remaining mantissa bit.
const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
APInt last_mantissa_bit_mask(nbits, 1);
last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;
// Compute rounding bias for round-to-nearest with ties to even. This is
// equal to a base value of 0111... plus one bit if the last remaining
// mantissa bit is 1.
const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
llvm::Value* x_last_mantissa_bit = b->CreateLShr(
b->CreateAnd(x_as_int,
llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
(23 - mantissa_bits));
(src_mantissa_bits - dest_mantissa_bits));
llvm::Value* x_rounding_bias =
b->CreateAdd(x_last_mantissa_bit,
llvm::ConstantInt::get(int_type, base_rounding_bias));
@@ -93,16 +109,19 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
// where adding the rounding bias overflows into the exponent bits is
// correct; the non-masked mantissa bits will all be zero, and the
// exponent will be incremented by one.
const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
x_as_int = b->CreateAnd(x_as_int,
llvm::ConstantInt::get(int_type, truncation_mask));
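// Worked example (editor's illustration, f32 reduced to 10 mantissa bits, so
// the shift is 23 - 10 = 13): last_mantissa_bit_mask is 0x2000 and
// base_rounding_bias is 0x0fff. When the retained low bit (bit 13) is 0 the
// bias cannot carry into bit 13, so an exact tie (discarded bits == 0x1000)
// rounds down; when bit 13 is 1 the bias becomes 0x1000 and the tie carries
// upward, i.e. ties round to the even result.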
}
if (exponent_bits < 8) {
// Masks for f32 values.
const uint32_t f32_sign_bit_mask = 1u << 31;
const uint32_t f32_exp_bits_mask = 0xffu << 23;
if (dest_exponent_bits < src_exponent_bits) {
APInt sign_bit_mask(nbits, 1);
sign_bit_mask <<= nbits - 1;
APInt exp_bits_mask(nbits, 1);
exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
<< src_mantissa_bits;
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
// significant bit -- is equal to 1.0f for all exponent sizes. Adding
@@ -116,28 +135,31 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
// (2^7-1) - 2^(n-1)-1.
//
// Note that we have already checked that exponent_bits >= 1.
const uint32_t f32_exponent_bias = (1 << 7) - 1;
const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
const uint32_t reduced_max_exponent =
f32_exponent_bias + reduced_exponent_bias;
const uint32_t reduced_min_exponent =
f32_exponent_bias - reduced_exponent_bias;
APInt exponent_bias(nbits, 1);
exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;
APInt reduced_exponent_bias(nbits, 1);
reduced_exponent_bias =
(reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;
APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;
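// Worked example (editor's illustration, f32 reduced to 5 exponent bits):
// exponent_bias = 127 and reduced_exponent_bias = 15, so reduced_max_exponent
// = 142 and reduced_min_exponent = 112. Biased exponents above 142 overflow
// to infinity; those at or below 112, including all denormals, flush to zero.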
// Do we overflow or underflow?
llvm::Value* x_exponent = b->CreateAnd(
x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
llvm::Value* x_exponent =
b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
llvm::Value* x_overflows = b->CreateICmpUGT(
x_exponent,
llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
x_exponent, llvm::ConstantInt::get(
int_type, reduced_max_exponent << src_mantissa_bits));
llvm::Value* x_underflows = b->CreateICmpULE(
x_exponent,
llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
x_exponent, llvm::ConstantInt::get(
int_type, reduced_min_exponent << src_mantissa_bits));
// Compute appropriately-signed values of zero and infinity.
llvm::Value* x_signed_zero = b->CreateAnd(
x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
llvm::Value* x_signed_zero =
b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
llvm::Value* x_signed_inf = b->CreateOr(
x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));
// Force to zero or infinity if overflow or underflow. (Note that this
// truncates all denormal values to zero, rather than rounding them.)
@@ -161,7 +183,7 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
if (!b->getFastMathFlags().noNaNs()) {
llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);
if (mantissa_bits > 0) {
if (dest_mantissa_bits > 0) {
result = b->CreateSelect(x_is_nan, x, result);
} else {
result = b->CreateSelect(
@@ -171,11 +193,14 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
return result;
}
llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, llvm::IRBuilder<>* b) {
auto reduced_precision = EmitReducePrecisionFloat(
f32_value,
/*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
/*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b);
StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
llvm::IRBuilder<>* b) {
TF_ASSIGN_OR_RETURN(
auto reduced_precision,
EmitReducePrecisionIR(
/*src_ty=*/F32, f32_value,
/*dest_exponent_bits=*/primitive_util::kBFloat16ExponentBits,
/*dest_mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b));
auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
auto shifted = b->CreateLShr(as_int32, 16);
auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
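// For example (editor's note), 1.0f reduce-precisions to itself, 0x3f800000,
// and its top 16 bits, 0x3f80, are exactly the bf16 encoding of 1.0.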
@@ -1099,11 +1124,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
const HloInstruction* hlo, llvm::Value* x) {
if (hlo->operand(0)->shape().element_type() != F32) {
return Unimplemented("reduce-precision only implemented for F32");
}
return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
/*mantissa_bits=*/hlo->mantissa_bits(), b_);
return EmitReducePrecisionIR(
/*src_ty=*/hlo->operand(0)->shape().element_type(), x,
/*dest_exponent_bits=*/hlo->exponent_bits(),
/*dest_mantissa_bits=*/hlo->mantissa_bits(), b_);
}
static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
@@ -1990,7 +2014,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
llvm::Value* float_val =
b_->CreateUIToFP(elem_index_linear, float_ir_type);
if (component_element_type == BF16) {
iota_result = EmitF32ToBF16(float_val, b_);
TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
} else {
iota_result = float_val;
}


@@ -67,6 +67,30 @@ T ToArithmeticSafeType(T t) {
return std::move(t);
}
// UintWithSize<N> gets an unsigned integer with the given size in bytes.
template <size_t Bytes>
struct UintWithSize {};
template <>
struct UintWithSize<1> {
using type = uint8;
};
template <>
struct UintWithSize<2> {
using type = uint16;
};
template <>
struct UintWithSize<4> {
using type = uint32;
};
template <>
struct UintWithSize<8> {
using type = uint64;
};
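// For example, UintWithSize<sizeof(float)>::type is uint32 and
// UintWithSize<sizeof(double)>::type is uint64.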
// Templated DfsHloVisitor for use by HloEvaluator.
//
// Typically ReturnT here indicates the resulting literal type of each evaluated
@@ -2389,48 +2413,57 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleCos<ElementwiseT>(cos);
}
template <typename NativeT, typename std::enable_if<std::is_same<
float, NativeT>::value>::type* = nullptr>
template <typename NativeT,
typename std::enable_if<
std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value>::type* = nullptr>
Status HandleReducePrecision(HloInstruction* reduce_precision) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[reduce_precision],
ElementWiseUnaryOp(reduce_precision, [reduce_precision](
ElementwiseT elem) {
uint32_t value_as_int = absl::bit_cast<uint32_t>(elem);
const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
const uint32_t exponent_bits = reduce_precision->exponent_bits();
ElementWiseUnaryOp(reduce_precision, [&](ElementwiseT elem) {
const uint32 src_mantissa_bits =
std::numeric_limits<NativeT>::digits - 1;
const uint32 src_exponent_bits =
8 * sizeof(NativeT) - src_mantissa_bits - 1;
const uint32 dest_mantissa_bits = reduce_precision->mantissa_bits();
const uint32 dest_exponent_bits = reduce_precision->exponent_bits();
using Uint = typename UintWithSize<sizeof(NativeT)>::type;
Uint value_as_int = absl::bit_cast<Uint>(elem);
// Code is based on the CPU/GPU implementation in LLVM-emitting code.
//
// Bits in float type:
// Bits in float32 type:
// mantissa : bits [0:22]
// exponent : bits [23:30]
// sign : bits [31]
if (mantissa_bits < 23) {
const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
if (dest_mantissa_bits < src_mantissa_bits) {
const Uint last_mantissa_bit_mask =
Uint{1} << (src_mantissa_bits - dest_mantissa_bits);
// Compute rounding bias for round-to-nearest with ties to even.
// This is equal to a base value of 0111... plus one bit if the last
// remaining mantissa bit is 1.
const uint32_t base_rounding_bias =
(last_mantissa_bit_mask >> 1) - 1;
const uint32_t x_last_mantissa_bit =
(value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
const uint32_t x_rounding_bias =
const Uint base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
const Uint x_last_mantissa_bit =
(value_as_int & last_mantissa_bit_mask) >>
(src_mantissa_bits - dest_mantissa_bits);
const Uint x_rounding_bias =
x_last_mantissa_bit + base_rounding_bias;
// Add rounding bias, and mask out truncated bits. Note that the
// case where adding the rounding bias overflows into the exponent
// bits is correct; the non-masked mantissa bits will all be zero,
// and the exponent will be incremented by one.
const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
const Uint truncation_mask = ~(last_mantissa_bit_mask - 1);
value_as_int = value_as_int + x_rounding_bias;
value_as_int = value_as_int & truncation_mask;
}
if (exponent_bits < 8) {
if (dest_exponent_bits < src_exponent_bits) {
// Masks for f32 values.
const uint32_t f32_sign_bit_mask = 1u << 31;
const uint32_t f32_exp_bits_mask = 0xffu << 23;
const Uint sign_bit_mask = Uint{1} << 8 * sizeof(NativeT) - 1;
const Uint exp_bits_mask = (Uint{1 << src_exponent_bits} - 1)
<< src_mantissa_bits;
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the
// most-significant bit -- is equal to 1.0f for all exponent sizes.
@@ -2444,23 +2477,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// is (2^7-1) - 2^(n-1)-1.
//
// Note that we have already checked that exponent_bits >= 1.
const uint32_t f32_exponent_bias = (1 << 7) - 1;
const uint32_t reduced_exponent_bias =
(1 << (exponent_bits - 1)) - 1;
const uint32_t reduced_max_exponent =
f32_exponent_bias + reduced_exponent_bias;
const uint32_t reduced_min_exponent =
f32_exponent_bias - reduced_exponent_bias;
const Uint exponent_bias = (Uint{1} << (src_exponent_bits - 1)) - 1;
const Uint reduced_exponent_bias =
(1 << (dest_exponent_bits - 1)) - 1;
const Uint reduced_max_exponent =
exponent_bias + reduced_exponent_bias;
const Uint reduced_min_exponent =
exponent_bias - reduced_exponent_bias;
// Do we overflow or underflow?
const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
const Uint x_exponent = value_as_int & exp_bits_mask;
const bool x_overflows =
x_exponent > (reduced_max_exponent << src_mantissa_bits);
const bool x_underflows =
x_exponent <= (reduced_min_exponent << 23);
x_exponent <= (reduced_min_exponent << src_mantissa_bits);
// Compute appropriately-signed values of zero and infinity.
const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;
const Uint x_signed_zero = value_as_int & sign_bit_mask;
const Uint x_signed_inf = x_signed_zero | exp_bits_mask;
// Force to zero or infinity if overflow or underflow. (Note that
// this truncates all denormal values to zero, rather than rounding
@@ -2469,23 +2503,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
value_as_int = x_underflows ? x_signed_zero : value_as_int;
}
float reduced_result = absl::bit_cast<float>(value_as_int);
NativeT reduced_result = absl::bit_cast<NativeT>(value_as_int);
if (std::isnan(elem)) {
reduced_result = mantissa_bits > 0
reduced_result = dest_mantissa_bits > 0
? elem
: std::numeric_limits<float>::infinity();
: std::numeric_limits<NativeT>::infinity();
}
return reduced_result;
}));
return Status::OK();
}
template <typename NativeT, typename std::enable_if<std::is_same<
double, NativeT>::value>::type* = nullptr>
Status HandleReducePrecision(HloInstruction* reduce_precision) {
return InvalidArgument("Double is not supported for reduce precision");
}
template <
typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value ||


@@ -39,189 +39,479 @@ limitations under the License.
namespace xla {
namespace {
// Tests to confirm that the ReducePrecision operation produces the expected
// numerical values.
class ReducePrecisionAccuracyTest : public ClientLibraryTestBase,
public ::testing::WithParamInterface<int> {
};
// For reduction to IEEE-f16, we want to test the following cases, in both
// positive and negative variants. (Note: IEEE-f16 is 5 exponent bits and 10
// mantissa bits.)
// Testcases in this file work as follows.
//
// Vectors of exponent and mantissa sizes to test. We want to test IEEE-f32 (a
// no-op), IEEE-f16, and exponent-reduction-only and mantissa-reduction-only
// variants of IEEE-f16.
static const int exponent_sizes[] = {8, 5, 5, 8};
static const int mantissa_sizes[] = {23, 10, 23, 10};
// for ty in {f16, bf16, f32, f64}:
// for (operation_index, (e, m)) in \
// enumerate(zip(ty_exponent_sizes, ty_mantissa_sizes)):
//
// for testcase in ty_test_values:
// let input = testcase[0]
// let expected = testcase[operation_index]
//
// CHECK that XLA-reduce-precision(
// input, /*exp_bits=*/e, /*mantissa_bits=*/m) == expected
//
// Put into words:
//
// - ty_{exponent,mantissa}_sizes tell us the different ways we will reduce the
// precision of `ty`.
//
// - ty_test_values is a 2D array of testcases, each of which is
// len(ty_exponent_sizes) elements long. The first element corresponds to
// the input, and the j'th element corresponds to the expected output of
// doing a reduce-precision with parameters ty_{exponent,mantissa}_sizes[j].
//
// You'll note above that testcase[0] is reused as both expected and input when
// operation_index == 0. This implies that ty_{exponent,mantissa}_sizes[0] must
// be equal to `ty`'s exponent/mantissa size, making the reduce-precision op
// tested a nop.
string TestDataToString(const ::testing::TestParamInfo<int> data) {
int i = data.param;
return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i],
"_mantissa_bits");
}
// We want to test IEEE-f16 (a nop), cases that reduce just the
// mantissa/exponent, and a case that reduces both.
//
// We don't have a lot of tests here, relying instead on the coverage we have of
// f32 and f64.
//
// Note: The hypothetical float(3,7) type we're "converting" to would have:
// max exp = 2^(3-1) - 1 = 3
// min exp = -max_exp + 1 = -2
static const int f16_exponent_sizes[] = {5, 3, 3, 5};
static const int f16_mantissa_sizes[] = {10, 7, 10, 7};
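// For example (editor's note), with float(3,7) the f16 value 2^-3 (biased
// exponent 12) has unbiased exponent -3, below the minimum of -2, so it
// flushes to zero, as the "largest exponent that underflows to zero" testcase
// below checks.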
// The FPVAL macro allows us to write out the binary representation of the
// input and expected values in a more readable manner. The mantissa bits
// are separated into the "high" bits (retained with reduction to IEEE-f16)
// and the "low" bits (truncated with reduction to IEEE-f16).
#define FPVAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
((0b##EXPONENT << 23) + (0b##HIGH_MANTISSA << 13) + (0b##LOW_MANTISSA))
// The F16VAL macro lets us write out the binary representation of an f16 in a
// more readable manner, separating out the exponent and mantissa.
#define F16VAL(EXPONENT, MANTISSA) ((EXPONENT << 10) + (MANTISSA))
// Each element in the test-value array consists of four numbers. The first is
// the input value and the following are the expected output values for the
// various precision-reduction cases.
static const uint32_t test_values[][4] = {
static const uint16 f16_test_values[][4] = {
// True zero.
{
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
FPVAL(00000000, 0000000000, 0000000000000) // 0.0
F16VAL(0b00000, 0b0000000000), // 0.0
F16VAL(0b00000, 0b0000000000), // 0.0
F16VAL(0b00000, 0b0000000000), // 0.0
F16VAL(0b00000, 0b0000000000), // 0.0
},
// Largest exponent that underflows to zero.
// One.
{
FPVAL(01110000, 0000000000, 0000000000000), // 3.05176e-05
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
FPVAL(01110000, 0000000000, 0000000000000) // 3.05176e-05
F16VAL(0b01111, 0b0000000000), // 1.0
F16VAL(0b01111, 0b0000000000), // 1.0
F16VAL(0b01111, 0b0000000000), // 1.0
F16VAL(0b01111, 0b0000000000), // 1.0
},
// Largest value that rounds to a denormal and thus clamps to zero.
// Largest exponent that underflows to zero is -3, which is encoded as
// -3 + 15 = 12
{
FPVAL(01110000, 1111111111, 0111111111111), // 6.10203e-05
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
FPVAL(01110000, 1111111111, 0000000000000) // 6.10054e-05
F16VAL(0b01100, 0b0000000000), // 2^-3
F16VAL(0b00000, 0b0000000000), // 0
F16VAL(0b00000, 0b0000000000), // 0
F16VAL(0b01100, 0b0000000000), // 2^-3
},
// Smallest value that doesn't underflow to zero, due to mantissa rounding
// up and incrementing the exponent out of the denormal range.
{
FPVAL(01110000, 1111111111, 1000000000000), // 6.10203e-05
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05
F16VAL(0b01100, 0b1111111100), // 1020 * 2^-3
F16VAL(0b01101, 0b0000000000), // 2^-2
F16VAL(0b00000, 0b0000000000), // 0
F16VAL(0b01101, 0b0000000000), // 2^-2
},
};
// We want to test bfloat16 (a nop), cases that reduce just the
// mantissa/exponent, and a case that reduces both.
//
// We don't have a lot of tests here, relying instead on the coverage we have of
// f32 and f64.
static const int bf16_exponent_sizes[] = {8, 5, 5, 8};
static const int bf16_mantissa_sizes[] = {7, 5, 7, 5};
// The BF16VAL macro lets us write out the binary representation of a bf16 in a
// more readable manner, separating out the exponent and mantissa.
#define BF16VAL(EXPONENT, MANTISSA) ((EXPONENT << 7) + (MANTISSA))
static const uint16 bf16_test_values[][4] = {
// True zero.
{
BF16VAL(0b00000, 0b0000000000), // 0.0
BF16VAL(0b00000, 0b0000000000), // 0.0
BF16VAL(0b00000, 0b0000000000), // 0.0
BF16VAL(0b00000, 0b0000000000), // 0.0
},
// One.
{
BF16VAL(0b01111111, 0b0000000), // 1.0
BF16VAL(0b01111111, 0b0000000), // 1.0
BF16VAL(0b01111111, 0b0000000), // 1.0
BF16VAL(0b01111111, 0b0000000), // 1.0
},
// Largest exponent that underflows to zero.
{
BF16VAL(0b01110000, 0b0000000), // 3.05176e-05
BF16VAL(0b00000000, 0b0000000), // 0.0
BF16VAL(0b00000000, 0b0000000), // 0.0
BF16VAL(0b01110000, 0b0000000) // 3.05176e-05
},
// Smallest value that doesn't underflow to zero, due to mantissa rounding
// up and incrementing the exponent out of the denormal range.
{
BF16VAL(0b01110000, 0b1111110), // 6.05583e-05
BF16VAL(0b01110001, 0b0000000), // 6.10352e-05
BF16VAL(0b00000000, 0b0000000), // 0.0
BF16VAL(0b01110001, 0b0000000), // 6.10352e-05
},
};
// We want to test IEEE-f32 (a no-op), IEEE-f16, and exponent-reduction-only and
// mantissa-reduction-only variants of IEEE-f16.
static const int f32_exponent_sizes[] = {8, 5, 5, 8};
static const int f32_mantissa_sizes[] = {23, 10, 23, 10};
// The F32VAL macro allows us to write out the binary representation of the
// input and expected values in a more readable manner. The mantissa bits
// are separated into the "high" bits (retained with reduction to IEEE-f16)
// and the "low" bits (truncated with reduction to IEEE-f16).
#define F32VAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
((EXPONENT << 23) + (HIGH_MANTISSA << 13) + (LOW_MANTISSA))
static const uint32 f32_test_values[][4] = {
// True zero.
{
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
F32VAL(0b00000000, 0b0000000000, 0b0000000000000) // 0.0
},
// Largest exponent that underflows to zero.
{
F32VAL(0b01110000, 0b0000000000, 0b0000000000000), // 3.05176e-05
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
F32VAL(0b01110000, 0b0000000000, 0b0000000000000) // 3.05176e-05
},
// Largest value that rounds to a denormal and thus clamps to zero.
{
F32VAL(0b01110000, 0b1111111111, 0b0111111111111), // 6.10203e-05
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
F32VAL(0b01110000, 0b1111111111, 0b0000000000000) // 6.10054e-05
},
// Smallest value that doesn't underflow to zero, due to mantissa rounding
// up and incrementing the exponent out of the denormal range.
{
F32VAL(0b01110000, 0b1111111111, 0b1000000000000), // 6.10203e-05
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
F32VAL(0b01110001, 0b0000000000, 0b0000000000000) // 6.10352e-05
},
// Smallest value that doesn't underflow to zero even without mantissa
// rounding.
{
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
F32VAL(0b01110001, 0b0000000000, 0b0000000000000) // 6.10352e-05
},
// One (to make sure bias-handling is done correctly.
// One (to make sure bias-handling is done correctly).
{
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
FPVAL(01111111, 0000000000, 0000000000000) // 1.0
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
F32VAL(0b01111111, 0b0000000000, 0b0000000000000) // 1.0
},
// Values in a space where ties round down due to ties-to-even:
// Value with highest mantissa that rounds down.
{
FPVAL(01111111, 0000000000, 1000000000000), // 1.00049
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
FPVAL(01111111, 0000000000, 1000000000000), // 1.00049
FPVAL(01111111, 0000000000, 0000000000000) // 1.0
F32VAL(0b01111111, 0b0000000000, 0b1000000000000), // 1.00049
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
F32VAL(0b01111111, 0b0000000000, 0b1000000000000), // 1.00049
F32VAL(0b01111111, 0b0000000000, 0b0000000000000) // 1.0
},
// Value with lowest mantissa that rounds up.
{
FPVAL(01111111, 0000000000, 1000000000001), // 1.00049
FPVAL(01111111, 0000000001, 0000000000000), // 1.00098
FPVAL(01111111, 0000000000, 1000000000001), // 1.00049
FPVAL(01111111, 0000000001, 0000000000000) // 1.00098
F32VAL(0b01111111, 0b0000000000, 0b1000000000001), // 1.00049
F32VAL(0b01111111, 0b0000000001, 0b0000000000000), // 1.00098
F32VAL(0b01111111, 0b0000000000, 0b1000000000001), // 1.00049
F32VAL(0b01111111, 0b0000000001, 0b0000000000000) // 1.00098
},
// Values in a space where ties round up due to ties-to-even:
// Value with highest mantissa that rounds down.
{
FPVAL(01111111, 0000000001, 0111111111111), // 1.00146
FPVAL(01111111, 0000000001, 0000000000000), // 1.00098
FPVAL(01111111, 0000000001, 0111111111111), // 1.00146
FPVAL(01111111, 0000000001, 0000000000000) // 1.00098
F32VAL(0b01111111, 0b0000000001, 0b0111111111111), // 1.00146
F32VAL(0b01111111, 0b0000000001, 0b0000000000000), // 1.00098
F32VAL(0b01111111, 0b0000000001, 0b0111111111111), // 1.00146
F32VAL(0b01111111, 0b0000000001, 0b0000000000000) // 1.00098
},
// Value with a mantissa that rounds up.
{
FPVAL(01111111, 0000000001, 1000000000000), // 1.00146
FPVAL(01111111, 0000000010, 0000000000000), // 1.00195
FPVAL(01111111, 0000000001, 1000000000000), // 1.00146
FPVAL(01111111, 0000000010, 0000000000000) // 1.00195
F32VAL(0b01111111, 0b0000000001, 0b1000000000000), // 1.00146
F32VAL(0b01111111, 0b0000000010, 0b0000000000000), // 1.00195
F32VAL(0b01111111, 0b0000000001, 0b1000000000000), // 1.00146
F32VAL(0b01111111, 0b0000000010, 0b0000000000000) // 1.00195
},
// Largest value that does not overflow to infinity.
{
FPVAL(10001110, 1111111111, 0111111111111), // 65520.0
FPVAL(10001110, 1111111111, 0000000000000), // 65504.0
FPVAL(10001110, 1111111111, 0111111111111), // 65520.0
FPVAL(10001110, 1111111111, 0000000000000) // 65504.0
F32VAL(0b10001110, 0b1111111111, 0b0111111111111), // 65520.0
F32VAL(0b10001110, 0b1111111111, 0b0000000000000), // 65504.0
F32VAL(0b10001110, 0b1111111111, 0b0111111111111), // 65520.0
F32VAL(0b10001110, 0b1111111111, 0b0000000000000) // 65504.0
},
// Smallest value that overflows to infinity due to mantissa rounding up.
{
FPVAL(10001110, 1111111111, 1000000000000), // 65520.0
FPVAL(11111111, 0000000000, 0000000000000), // Inf
FPVAL(10001110, 1111111111, 1000000000000), // 65520.0
FPVAL(10001111, 0000000000, 0000000000000) // 65536.0
F32VAL(0b10001110, 0b1111111111, 0b1000000000000), // 65520.0
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
F32VAL(0b10001110, 0b1111111111, 0b1000000000000), // 65520.0
F32VAL(0b10001111, 0b0000000000, 0b0000000000000) // 65536.0
},
// Smallest value that overflows to infinity, without mantissa rounding.
{
FPVAL(10001111, 0000000000, 0000000000000), // 65536.0
FPVAL(11111111, 0000000000, 0000000000000), // Inf
FPVAL(11111111, 0000000000, 0000000000000), // Inf
FPVAL(10001111, 0000000000, 0000000000000) // 65536.0
F32VAL(0b10001111, 0b0000000000, 0b0000000000000), // 65536.0
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
F32VAL(0b10001111, 0b0000000000, 0b0000000000000) // 65536.0
},
// Smallest value that overflows to infinity due to mantissa rounding up,
// even when exponent bits aren't reduced.
{
FPVAL(11111110, 1111111111, 1000000000000), // 3.40199e+38
FPVAL(11111111, 0000000000, 0000000000000), // Inf
FPVAL(11111111, 0000000000, 0000000000000), // Inf
FPVAL(11111111, 0000000000, 0000000000000) // Inf
F32VAL(0b11111110, 0b1111111111, 0b1000000000000), // 3.40199e+38
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
F32VAL(0b11111111, 0b0000000000, 0b0000000000000) // Inf
},
// True infinity.
{
FPVAL(11111111, 0000000000, 0000000000000), // Inf
FPVAL(11111111, 0000000000, 0000000000000), // Inf
FPVAL(11111111, 0000000000, 0000000000000), // Inf
FPVAL(11111111, 0000000000, 0000000000000) // Inf
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
F32VAL(0b11111111, 0b0000000000, 0b0000000000000) // Inf
},
// NAN with a 1 in the preserved bits.
{
FPVAL(11111111, 1000000000, 0000000000000), // NaN
FPVAL(11111111, 1000000000, 0000000000000), // NaN
FPVAL(11111111, 1000000000, 0000000000000), // NaN
FPVAL(11111111, 1000000000, 0000000000000) // NaN
F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN
F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN
F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN
F32VAL(0b11111111, 0b1000000000, 0b0000000000000) // NaN
},
// NAN with a 1 in the truncated bits.
{
FPVAL(11111111, 0000000000, 0000000000001), // NaN
FPVAL(11111111, 0000000000, 0000000000001), // NaN
FPVAL(11111111, 0000000000, 0000000000001), // NaN
FPVAL(11111111, 0000000000, 0000000000001) // NaN
F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN
F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN
F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN
F32VAL(0b11111111, 0b0000000000, 0b0000000000001) // NaN
},
// NAN with all ones, causing rounding overflow.
{
FPVAL(11111111, 1111111111, 1111111111111), // NaN
FPVAL(11111111, 1111111111, 1111111111111), // NaN
FPVAL(11111111, 1111111111, 1111111111111), // NaN
FPVAL(11111111, 1111111111, 1111111111111) // NaN
F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN
F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN
F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN
F32VAL(0b11111111, 0b1111111111, 0b1111111111111) // NaN
}};
XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
int index = GetParam();
int exponent_bits = exponent_sizes[index];
int mantissa_bits = mantissa_sizes[index];
// F64VAL is like F32VAL but for doubles.
//
// Here the "high" mantissa bits are those retained with reduction to IEEE-f32
// (the first 23 bits), and the "low" bits are those truncated with reduction to
// IEEE-f32 (the remaining 29 bits).
#define F64VAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
((uint64{EXPONENT} << 52) + (uint64{HIGH_MANTISSA} << 29) + \
uint64{LOW_MANTISSA})
std::vector<float> input_values;
std::vector<float> expected_values;
// We want to test IEEE-f64 (a no-op), IEEE-f32, and exponent-reduction-only and
// mantissa-reduction-only variants of IEEE-f32.
static const int f64_exponent_sizes[] = {11, 8, 8, 11};
static const int f64_mantissa_sizes[] = {52, 23, 52, 23};
const uint32_t sign_bit = 1u << 31;
static const uint64 f64_test_values[][4] = {
// True zero.
{
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
},
// Largest exponent that underflows to zero, namely -127 (encoded as
// -127 + 1023).
{
F64VAL(0b01110000000, 0x000000, 0x00000000), // 5.8774717541114375e-39
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
F64VAL(0b01110000000, 0x000000, 0x00000000), // 5.8774717541114375e-39
},
// Largest value that rounds to a denormal and thus clamps to zero.
{
F64VAL(0b01110000000, 0x7FFFFF, 0x0FFFFFFF), // 1.1754943157898258e-38
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
F64VAL(0b01110000000, 0x7FFFFF, 0x00000000), // 1.1754942807573643e-38
},
// Smallest value that doesn't underflow to zero, due to mantissa rounding
// up and incrementing the exponent out of the denormal range.
{
F64VAL(0b01110000000, 0x7FFFFF, 0x10000000), // 1.1754943157898259e-38
F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
F64VAL(0b01110000001, 0x000000, 0x00000000) // 1.1754943508222875e-38
},
// Smallest value that doesn't underflow to zero even without mantissa
// rounding.
{
F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38
F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38
F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38
F64VAL(0b01110000001, 0x000000, 0x00000000) // 1.1754943508222875e-38
},
// One (to make sure bias-handling is done correctly).
{
F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0
F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0
F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0
F64VAL(0b01111111111, 0x000000, 0x00000000) // 1.0
},
// Values in a space where ties round down due to ties-to-even:
// Value with highest mantissa that rounds down.
{
F64VAL(0b01111111111, 0x000000, 0x10000000), // 1.0000000596046448
F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0
F64VAL(0b01111111111, 0x000000, 0x10000000), // 1.0000000596046448
F64VAL(0b01111111111, 0x000000, 0x00000000) // 1.0
},
// Value with lowest mantissa that rounds up.
{
F64VAL(0b01111111111, 0x000000, 0x10000001), // 1.000000059604645
F64VAL(0b01111111111, 0x000001, 0x00000000), // 1.0000001192092896
F64VAL(0b01111111111, 0x000000, 0x10000001), // 1.000000059604645
F64VAL(0b01111111111, 0x000001, 0x00000000) // 1.0000001192092896
},
// Values in a space where ties round up due to ties-to-even:
// Value with highest mantissa that rounds down.
{
F64VAL(0b01111111111, 0x000001, 0x0fffffff), // 1.0000001788139341
F64VAL(0b01111111111, 0x000001, 0x00000000), // 1.0000001192092896
F64VAL(0b01111111111, 0x000001, 0x0fffffff), // 1.0000001788139341
F64VAL(0b01111111111, 0x000001, 0x00000000) // 1.0000001192092896
},
// Value with a mantissa that rounds up.
{
F64VAL(0b01111111111, 0x000001, 0x10000000), // 1.0000001788139343
F64VAL(0b01111111111, 0x000002, 0x00000000), // 1.0000002384185791
F64VAL(0b01111111111, 0x000001, 0x10000000), // 1.0000001788139343
F64VAL(0b01111111111, 0x000002, 0x00000000), // 1.0000002384185791
},
// Largest value that does not overflow to infinity.
{
F64VAL(0b10001111110, 0x7fffff, 0x0fffffff), // 3.4028235677973362e+38
F64VAL(0b10001111110, 0x7fffff, 0x00000000), // 3.4028234663852886e+38
F64VAL(0b10001111110, 0x7fffff, 0x0fffffff), // 3.4028235677973362e+38
F64VAL(0b10001111110, 0x7fffff, 0x00000000), // 3.4028234663852886e+38
},
// Smallest value that overflows to infinity due to mantissa rounding up.
{
F64VAL(0b10001111110, 0x7fffff, 0x10000000), // 3.4028235677973366e+38
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
F64VAL(0b10001111110, 0x7fffff, 0x10000000), // 3.4028235677973366e+38
F64VAL(0b10001111111, 0x000000, 0x00000000) // 3.4028236692093846e+38
},
// Smallest value that overflows to infinity, without mantissa rounding.
{
F64VAL(0b10001111111, 0x000000, 0x00000000), // 3.4028236692093846e+38
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
F64VAL(0b10001111111, 0x000000, 0x00000000) // 3.4028236692093846e+38
},
// Smallest value that overflows to infinity due to mantissa rounding up,
// even when exponent bits aren't reduced.
{
F64VAL(0b11111111110, 0x7fffff, 0x10000000), // 1.7976930812868855e+308
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
F64VAL(0b11111111111, 0x000000, 0x00000000) // Inf
},
// True infinity.
{
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
},
// NAN with a 1 in the preserved bits.
{
F64VAL(0b11111111111, 0x400000, 0x00000000), // NaN
F64VAL(0b11111111111, 0x400000, 0x00000000), // NaN
F64VAL(0b11111111111, 0x400000, 0x00000000), // NaN
F64VAL(0b11111111111, 0x400000, 0x00000000), // NaN
},
// NAN with a 1 in the truncated bits.
{
F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN
F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN
F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN
F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN
},
// NAN with all ones, causing rounding overflow.
{
F64VAL(0b11111111111, 0x7fffff, 0x1fffffff), // NaN
F64VAL(0b11111111111, 0x7fffff, 0x1fffffff), // NaN
F64VAL(0b11111111111, 0x7fffff, 0x1fffffff), // NaN
F64VAL(0b11111111111, 0x7fffff, 0x1fffffff), // NaN
},
};
class ReducedPrecisionAccuracyTest : public ClientLibraryTestBase,
public ::testing::WithParamInterface<int> {
protected:
template <typename Fp, typename Uint, int kNumTestcases, int kNumInputs>
void DoIt(int exponent_bits, int mantissa_bits,
const Uint (&test_values)[kNumInputs][kNumTestcases],
int operation_index);
};
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionHalf) {
int operation_index = GetParam();
DoIt<Eigen::half, uint16>(f16_exponent_sizes[operation_index],
f16_mantissa_sizes[operation_index],
f16_test_values, operation_index);
}
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionBfloat16) {
int operation_index = GetParam();
DoIt<bfloat16, uint16>(bf16_exponent_sizes[operation_index],
bf16_mantissa_sizes[operation_index], bf16_test_values,
operation_index);
}
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionFloat) {
int operation_index = GetParam();
DoIt<float, uint32>(f32_exponent_sizes[operation_index],
f32_mantissa_sizes[operation_index], f32_test_values,
operation_index);
}
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionDouble) {
int operation_index = GetParam();
DoIt<double, uint64>(f64_exponent_sizes[operation_index],
f64_mantissa_sizes[operation_index], f64_test_values,
operation_index);
}
template <typename Fp, typename Uint, int kNumTestcases, int kNumInputs>
void ReducedPrecisionAccuracyTest::DoIt(
int exponent_bits, int mantissa_bits,
const Uint (&test_values)[kNumInputs][kNumTestcases], int operation_index) {
SCOPED_TRACE(absl::StrFormat("operation_index %d", operation_index));
SCOPED_TRACE(absl::StrFormat("%d exponent bits, %d mantissa bits",
exponent_bits, mantissa_bits));
std::vector<Fp> input_values;
std::vector<Fp> expected_values;
const Uint sign_bit = Uint{1} << (sizeof(Fp) * 8 - 1);
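// E.g. (editor's note) for Fp == float, sign_bit is 0x80000000; XOR-ing it
// flips only the sign bit, so negative inputs can be formed even from NaN
// bit patterns without doing floating-point arithmetic on them.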
for (const auto& test_value : test_values) {
// Add positive values.
input_values.push_back(absl::bit_cast<float>(test_value[0]));
expected_values.push_back(absl::bit_cast<float>(test_value[index]));
input_values.push_back(absl::bit_cast<Fp>(test_value[0]));
expected_values.push_back(absl::bit_cast<Fp>(test_value[operation_index]));
// Add negative values. We do this in the bitwise representation so as to
// avoid problems with NaN handling.
input_values.push_back(absl::bit_cast<float>(test_value[0] ^ sign_bit));
input_values.push_back(absl::bit_cast<Fp, Uint>(test_value[0] ^ sign_bit));
expected_values.push_back(
absl::bit_cast<float>(test_value[index] ^ sign_bit));
absl::bit_cast<Fp, Uint>(test_value[operation_index] ^ sign_bit));
}
// This is required for proper handling of NaN values.
@@ -229,19 +519,18 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
XlaBuilder builder(TestName());
Literal a_literal = LiteralUtil::CreateR1<float>({input_values});
Literal a_literal = LiteralUtil::CreateR1<Fp>({input_values});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal.shape(), "a");
ReducePrecision(a, exponent_bits, mantissa_bits);
ComputeAndCompareR1<float>(&builder, expected_values, {a_data.get()});
ComputeAndCompareR1<Fp>(&builder, expected_values, {a_data.get()});
}
INSTANTIATE_TEST_CASE_P(ReducePrecisionAccuracyTest,
ReducePrecisionAccuracyTest,
::testing::Values(0, 1, 2, 3), TestDataToString);
INSTANTIATE_TEST_CASE_P(ReducedPrecisionAccuracyTest,
ReducedPrecisionAccuracyTest, ::testing::Range(0, 4));
// Tests to confirm that the compiler optimization functions add the expected
// ReducePrecisionInsertion passes.