diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 30f6c486e1e..a31c1cafdce 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -63,28 +63,44 @@ int64 GlobalRandomValue() {
   return rng();
 }
 
-llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
-                                      int64 mantissa_bits,
-                                      llvm::IRBuilder<>* b) {
+StatusOr<llvm::Value*> EmitReducePrecisionIR(PrimitiveType src_ty,
+                                             llvm::Value* x,
+                                             int64 dest_exponent_bits,
+                                             int64 dest_mantissa_bits,
+                                             llvm::IRBuilder<>* b) {
+  using llvm::APInt;
+
+  if (!primitive_util::IsFloatingPointType(src_ty)) {
+    return Unimplemented(
+        "ReducePrecision cannot accept non-floating-point type %s.",
+        PrimitiveType_Name(src_ty));
+  }
+
   // Integer and float types for casting and constant generation.
   llvm::Type* float_type = x->getType();
-  llvm::IntegerType* int_type = b->getInt32Ty();
+  int64 nbits = float_type->getPrimitiveSizeInBits();
+  llvm::IntegerType* int_type = b->getIntNTy(nbits);
+
+  // SignificandWidth includes the implicit extra bit.
+  int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
+  int src_exponent_bits = nbits - 1 - src_mantissa_bits;
 
   // Cast the input value to an integer for bitwise manipulation.
   llvm::Value* x_as_int = b->CreateBitCast(x, int_type);
 
-  if (mantissa_bits < 23) {
+  if (dest_mantissa_bits < src_mantissa_bits) {
     // Last remaining mantissa bit.
-    const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
+    APInt last_mantissa_bit_mask(nbits, 1);
+    last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;
 
     // Compute rounding bias for round-to-nearest with ties to even. This is
     // equal to a base value of 0111... plus one bit if the last remaining
     // mantissa bit is 1.
-    const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
+    APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
     llvm::Value* x_last_mantissa_bit = b->CreateLShr(
         b->CreateAnd(x_as_int,
                      llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
-        (23 - mantissa_bits));
+        (src_mantissa_bits - dest_mantissa_bits));
     llvm::Value* x_rounding_bias =
         b->CreateAdd(x_last_mantissa_bit,
                      llvm::ConstantInt::get(int_type, base_rounding_bias));
@@ -93,16 +109,19 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
     // where adding the rounding bias overflows into the exponent bits is
     // correct; the non-masked mantissa bits will all be zero, and the
     // exponent will be incremented by one.
-    const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
+    APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
     x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
     x_as_int = b->CreateAnd(x_as_int,
                             llvm::ConstantInt::get(int_type, truncation_mask));
   }
 
-  if (exponent_bits < 8) {
-    // Masks for f32 values.
-    const uint32_t f32_sign_bit_mask = 1u << 31;
-    const uint32_t f32_exp_bits_mask = 0xffu << 23;
+  if (dest_exponent_bits < src_exponent_bits) {
+    APInt sign_bit_mask(nbits, 1);
+    sign_bit_mask <<= nbits - 1;
+
+    APInt exp_bits_mask(nbits, 1);
+    exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
+                    << src_mantissa_bits;
 
     // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
@@ -116,28 +135,31 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
     // (2^7-1) - 2^(n-1)-1.
     //
     // Note that we have already checked that exponents_bits >= 1.
-    const uint32_t f32_exponent_bias = (1 << 7) - 1;
-    const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
-    const uint32_t reduced_max_exponent =
-        f32_exponent_bias + reduced_exponent_bias;
-    const uint32_t reduced_min_exponent =
-        f32_exponent_bias - reduced_exponent_bias;
+    APInt exponent_bias(nbits, 1);
+    exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;
+
+    APInt reduced_exponent_bias(nbits, 1);
+    reduced_exponent_bias =
+        (reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;
+
+    APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
+    APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;
 
     // Do we overflow or underflow?
-    llvm::Value* x_exponent = b->CreateAnd(
-        x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
+    llvm::Value* x_exponent =
+        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
     llvm::Value* x_overflows = b->CreateICmpUGT(
-        x_exponent,
-        llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
+        x_exponent, llvm::ConstantInt::get(
+                        int_type, reduced_max_exponent << src_mantissa_bits));
     llvm::Value* x_underflows = b->CreateICmpULE(
-        x_exponent,
-        llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
+        x_exponent, llvm::ConstantInt::get(
+                        int_type, reduced_min_exponent << src_mantissa_bits));
 
     // Compute appropriately-signed values of zero and infinity.
-    llvm::Value* x_signed_zero = b->CreateAnd(
-        x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
+    llvm::Value* x_signed_zero =
+        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
     llvm::Value* x_signed_inf = b->CreateOr(
-        x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
+        x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));
 
     // Force to zero or infinity if overflow or underflow. (Note that this
     // truncates all denormal values to zero, rather than rounding them.)
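Reviewer note (not part of the patch): the IR-builder form of the rounding logic above can be hard to follow, so here is a minimal standalone C++ sketch of the same round-to-nearest-ties-to-even mantissa reduction, hard-coded to f32 (23 mantissa bits) for readability. All names below are local to the sketch and are not used by the patch.

#include <cstdint>

// Round an f32 bit pattern down to `dest_mantissa_bits` mantissa bits using
// the same rounding-bias trick as EmitReducePrecisionIR above.
uint32_t ReduceMantissaF32(uint32_t x_as_int, int dest_mantissa_bits) {
  const int src_mantissa_bits = 23;
  if (dest_mantissa_bits >= src_mantissa_bits) return x_as_int;
  // Lowest mantissa bit that survives the reduction.
  const uint32_t last_mantissa_bit_mask =
      1u << (src_mantissa_bits - dest_mantissa_bits);
  // A bias of 0111..., plus one more if the lowest surviving bit is 1; that
  // extra bit is what turns round-half-up into round-to-nearest-ties-to-even.
  const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
  const uint32_t x_last_mantissa_bit =
      (x_as_int & last_mantissa_bit_mask) >>
      (src_mantissa_bits - dest_mantissa_bits);
  x_as_int += base_rounding_bias + x_last_mantissa_bit;
  // Clear the truncated bits; a carry into the exponent field is the correct
  // rounded result.
  return x_as_int & ~(last_mantissa_bit_mask - 1);
}

For example, ReduceMantissaF32(0x3F800001u, 10) returns 0x3F800000u: a value one ulp above 1.0f rounds back to 1.0f at f16 mantissa precision, while halfway cases keep the even neighbour.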
@@ -161,7 +183,7 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
   if (!b->getFastMathFlags().noNaNs()) {
     llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);
 
-    if (mantissa_bits > 0) {
+    if (dest_mantissa_bits > 0) {
       result = b->CreateSelect(x_is_nan, x, result);
     } else {
       result = b->CreateSelect(
@@ -171,11 +193,14 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
   return result;
 }
 
-llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, llvm::IRBuilder<>* b) {
-  auto reduced_precision = EmitReducePrecisionFloat(
-      f32_value,
-      /*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
-      /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b);
+StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
+                                     llvm::IRBuilder<>* b) {
+  TF_ASSIGN_OR_RETURN(
+      auto reduced_precision,
+      EmitReducePrecisionIR(
+          /*src_ty=*/F32, f32_value,
+          /*dest_exponent_bits=*/primitive_util::kBFloat16ExponentBits,
+          /*dest_mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b));
   auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
   auto shifted = b->CreateLShr(as_int32, 16);
   auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
@@ -1099,11 +1124,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
     const HloInstruction* hlo, llvm::Value* x) {
-  if (hlo->operand(0)->shape().element_type() != F32) {
-    return Unimplemented("reduce-precision only implemented for F32");
-  }
-  return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
-                                  /*mantissa_bits=*/hlo->mantissa_bits(), b_);
+  return EmitReducePrecisionIR(
+      /*src_ty=*/hlo->operand(0)->shape().element_type(), x,
+      /*dest_exponent_bits=*/hlo->exponent_bits(),
+      /*dest_mantissa_bits=*/hlo->mantissa_bits(), b_);
 }
 
 static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
@@ -1990,7 +2014,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
         llvm::Value* float_val =
             b_->CreateUIToFP(elem_index_linear, float_ir_type);
         if (component_element_type == BF16) {
-          iota_result = EmitF32ToBF16(float_val, b_);
+          TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
         } else {
           iota_result = float_val;
         }
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 9774ff51283..9fcc6274866 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -67,6 +67,30 @@ T ToArithmeticSafeType(T t) {
   return std::move(t);
 }
 
+// UintWithSize gets an unsigned integer with the given size in bytes.
+template <int kNumBytes>
+struct UintWithSize {};
+
+template <>
+struct UintWithSize<1> {
+  using type = uint8;
+};
+
+template <>
+struct UintWithSize<2> {
+  using type = uint16;
+};
+
+template <>
+struct UintWithSize<4> {
+  using type = uint32;
+};
+
+template <>
+struct UintWithSize<8> {
+  using type = uint64;
+};
+
 // Templated DfsHloVisitor for use by HloEvaluator.
 //
 // Typically ReturnT here indicates the resulting literal type of each evaluated
@@ -2389,48 +2413,57 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     return HandleCos(cos);
   }
 
-  template <typename NativeT,
-            typename std::enable_if<
-                std::is_same<float, NativeT>::value>::type* = nullptr>
+  template <typename NativeT,
+            typename std::enable_if<std::is_same<float, NativeT>::value ||
+                                    std::is_same<double, NativeT>::value>::
+                type* = nullptr>
   Status HandleReducePrecision(HloInstruction* reduce_precision) {
     TF_ASSIGN_OR_RETURN(
         parent_->evaluated_[reduce_precision],
-        ElementWiseUnaryOp(reduce_precision, [reduce_precision](
-                                                 ElementwiseT elem) {
-          uint32_t value_as_int = absl::bit_cast<uint32_t>(elem);
-          const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
-          const uint32_t exponent_bits = reduce_precision->exponent_bits();
+        ElementWiseUnaryOp(reduce_precision, [&](ElementwiseT elem) {
+          const uint32 src_mantissa_bits =
+              std::numeric_limits<NativeT>::digits - 1;
+          const uint32 src_exponent_bits =
+              8 * sizeof(NativeT) - src_mantissa_bits - 1;
+          const uint32 dest_mantissa_bits = reduce_precision->mantissa_bits();
+          const uint32 dest_exponent_bits = reduce_precision->exponent_bits();
+
+          using Uint = typename UintWithSize<sizeof(NativeT)>::type;
+          Uint value_as_int = absl::bit_cast<Uint>(elem);
 
           // Code is based on the CPU/GPU implementation in LLVM-emitting code.
           //
-          // Bits in float type:
+          // Bits in float32 type:
           //   mantissa : bits [0:22]
           //   exponent : bits [23:30]
           //   sign : bits [31]
-          if (mantissa_bits < 23) {
-            const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
+          if (dest_mantissa_bits < src_mantissa_bits) {
+            const Uint last_mantissa_bit_mask =
+                Uint{1} << (src_mantissa_bits - dest_mantissa_bits);
 
             // Compute rounding bias for round-to-nearest with ties to even.
            // This is equal to a base value of 0111... plus one bit if the last
             // remaining mantissa bit is 1.
-            const uint32_t base_rounding_bias =
-                (last_mantissa_bit_mask >> 1) - 1;
-            const uint32_t x_last_mantissa_bit =
-                (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
-            const uint32_t x_rounding_bias =
+            const Uint base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
+            const Uint x_last_mantissa_bit =
+                (value_as_int & last_mantissa_bit_mask) >>
+                (src_mantissa_bits - dest_mantissa_bits);
+            const Uint x_rounding_bias =
                 x_last_mantissa_bit + base_rounding_bias;
 
             // Add rounding bias, and mask out truncated bits. Note that the
             // case where adding the rounding bias overflows into the exponent
             // bits is correct; the non-masked mantissa bits will all be zero,
             // and the exponent will be incremented by one.
-            const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
+            const Uint truncation_mask = ~(last_mantissa_bit_mask - 1);
             value_as_int = value_as_int + x_rounding_bias;
             value_as_int = value_as_int & truncation_mask;
           }
 
-          if (exponent_bits < 8) {
+          if (dest_exponent_bits < src_exponent_bits) {
            // Masks for f32 values.
-            const uint32_t f32_sign_bit_mask = 1u << 31;
-            const uint32_t f32_exp_bits_mask = 0xffu << 23;
+            const Uint sign_bit_mask = Uint{1} << 8 * sizeof(NativeT) - 1;
+            const Uint exp_bits_mask = (Uint{1 << src_exponent_bits} - 1)
+                                       << src_mantissa_bits;
 
             // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the
             // most- significant bit -- is equal to 1.0f for all exponent sizes.
@@ -2444,23 +2477,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
             // is (2^7-1) - 2^(n-1)-1.
             //
             // Note that we have already checked that exponents_bits >= 1.
-            const uint32_t f32_exponent_bias = (1 << 7) - 1;
-            const uint32_t reduced_exponent_bias =
-                (1 << (exponent_bits - 1)) - 1;
-            const uint32_t reduced_max_exponent =
-                f32_exponent_bias + reduced_exponent_bias;
-            const uint32_t reduced_min_exponent =
-                f32_exponent_bias - reduced_exponent_bias;
+            const Uint exponent_bias = (Uint{1} << (src_exponent_bits - 1)) - 1;
+            const Uint reduced_exponent_bias =
+                (1 << (dest_exponent_bits - 1)) - 1;
+            const Uint reduced_max_exponent =
+                exponent_bias + reduced_exponent_bias;
+            const Uint reduced_min_exponent =
+                exponent_bias - reduced_exponent_bias;
 
             // Do we overflow or underflow?
-            const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
-            const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
+            const Uint x_exponent = value_as_int & exp_bits_mask;
+            const bool x_overflows =
+                x_exponent > (reduced_max_exponent << src_mantissa_bits);
             const bool x_underflows =
-                x_exponent <= (reduced_min_exponent << 23);
+                x_exponent <= (reduced_min_exponent << src_mantissa_bits);
 
             // Compute appropriately-signed values of zero and infinity.
-            const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
-            const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;
+            const Uint x_signed_zero = value_as_int & sign_bit_mask;
+            const Uint x_signed_inf = x_signed_zero | exp_bits_mask;
 
             // Force to zero or infinity if overflow or underflow. (Note that
             // this truncates all denormal values to zero, rather than rounding
@@ -2469,23 +2503,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
             value_as_int = x_underflows ? x_signed_zero : value_as_int;
           }
 
-          float reduced_result = absl::bit_cast<float>(value_as_int);
+          NativeT reduced_result = absl::bit_cast<NativeT>(value_as_int);
           if (std::isnan(elem)) {
-            reduced_result = mantissa_bits > 0
+            reduced_result = dest_mantissa_bits > 0
                                  ? elem
-                                 : std::numeric_limits<float>::infinity();
+                                 : std::numeric_limits<NativeT>::infinity();
           }
           return reduced_result;
         }));
     return Status::OK();
   }
 
-  template <typename NativeT,
-            typename std::enable_if<
-                std::is_same<double, NativeT>::value>::type* = nullptr>
-  Status HandleReducePrecision(HloInstruction* reduce_precision) {
-    return InvalidArgument("Double is not supported for reduce precision");
-  }
-
   template <
       typename NativeT,
       typename std::enable_if::value ||
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index e2cf4c0be28..b5d907fd535 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -39,189 +39,479 @@ limitations under the License.
 
 namespace xla {
 namespace {
 
-// Tests to confirm that the ReducePrecision operation produces the expected
-// numerical values.
-class ReducePrecisionAccuracyTest : public ClientLibraryTestBase,
-                                    public ::testing::WithParamInterface<int> {
-};
-
-// For reduction to IEEE-f16, we want to test the following cases, in both
-// positive and negative variants. (Note: IEEE-f16 is 5 exponent bits and 10
-// mantissa bits.)
-//
-// Vectors of exponent and mantissa sizes to test. We want to test IEEE-f32 (a
-// no-op), IEEE-f16, and exponent-reduction-only and mantissa-reduction-only
-// variants of IEEE-f16.
-static const int exponent_sizes[] = {8, 5, 5, 8}; -static const int mantissa_sizes[] = {23, 10, 23, 10}; +// for ty in {f16, bf16, f32, f64}: +// for (operation_index, (e, m)) in \ +// enumerate(zip(ty_exponent_sizes, ty_mantissa_sizes)): +// +// for testcase in ty_test_values: +// let expected = testcase[0] +// let input = testcase[operation_index] +// +// CHECK that XLA-reduce-precision( +// input, /*exp_bits=*/e, /*mantissa_bits=*/m) == expected +// +// Put into words: +// +// - ty_{exponent,mantissa}_sizes tell us the different ways we will reduce the +// precision of `ty`. +// +// - ty_test_values is a 2D array of testcases, each of which is +// len(ty_exponent_sizes) elements long. The first element corresponds to +// the input, and the j'th element corresponds to the expected output of +// doing a reduce-precision with parameters ty_{exponent,mantissa}_sizes[j]. +// +// You'll note above that testcase[0] is reused as both expected and input when +// operation_index == 0. This implies that ty_{exponent,mantissa}_sizes[0] must +// be equal to `ty`'s exponent/mantissa size, making the reduce-precision op +// tested a nop. -string TestDataToString(const ::testing::TestParamInfo data) { - int i = data.param; - return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i], - "_mantissa_bits"); -} +// We want to test IEEE-f16 (a nop), cases that reduce just the +// mantissa/exponent, and a case that reduces both. +// +// We don't have a lot of tests here, relying instead on the coverage we have of +// f32 and f64. +// +// Note: The hypothetical float(3,7) type we're "converting" to would have: +// max exp = 2^(3-1) - 1 = 3 +// min exp = -max_exp + 1 = -2 +static const int f16_exponent_sizes[] = {5, 3, 3, 5}; +static const int f16_mantissa_sizes[] = {10, 7, 10, 7}; -// The FPVAL macro allows us to write out the binary representation of the -// input and expected values in a more readable manner. The mantissa bits -// are separated into the "high" bits (retained with reduction to IEEE-f16) -// and the "low" bits (truncated with reduction to IEEE-f16). -#define FPVAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \ - ((0b##EXPONENT << 23) + (0b##HIGH_MANTISSA << 13) + (0b##LOW_MANTISSA)) +// The F16VAL macro lets us write out the binary representation of an f16 in a +// more readable manner, separating out the exponent and mantissa. +#define F16VAL(EXPONENT, MANTISSA) ((EXPONENT << 10) + (MANTISSA)) -// Each element in the test-value array consists of four numbers. The first is -// the input value and the following are the expected output values for the -// various precision-reduction cases. -static const uint32_t test_values[][4] = { +static const uint16 f16_test_values[][4] = { // True zero. { - FPVAL(00000000, 0000000000, 0000000000000), // 0.0 - FPVAL(00000000, 0000000000, 0000000000000), // 0.0 - FPVAL(00000000, 0000000000, 0000000000000), // 0.0 - FPVAL(00000000, 0000000000, 0000000000000) // 0.0 + F16VAL(0b00000, 0b0000000000), // 0.0 + F16VAL(0b00000, 0b0000000000), // 0.0 + F16VAL(0b00000, 0b0000000000), // 0.0 + F16VAL(0b00000, 0b0000000000), // 0.0 }, - // Largest exponent that underflows to zero. + // One. 
{ - FPVAL(01110000, 0000000000, 0000000000000), // 3.05176e-05 - FPVAL(00000000, 0000000000, 0000000000000), // 0.0 - FPVAL(00000000, 0000000000, 0000000000000), // 0.0 - FPVAL(01110000, 0000000000, 0000000000000) // 3.05176e-05 + F16VAL(0b01111, 0b0000000000), // 1.0 + F16VAL(0b01111, 0b0000000000), // 1.0 + F16VAL(0b01111, 0b0000000000), // 1.0 + F16VAL(0b01111, 0b0000000000), // 1.0 }, - // Largest value that rounds to a denormal and thus clamps to zero. + // Largest exponent that underflows to zero is -3, which is encoded as + // -3 + 15 = 12 { - FPVAL(01110000, 1111111111, 0111111111111), // 6.10203e-05 - FPVAL(00000000, 0000000000, 0000000000000), // 0.0 - FPVAL(00000000, 0000000000, 0000000000000), // 0.0 - FPVAL(01110000, 1111111111, 0000000000000) // 6.10054e-05 + F16VAL(0b01100, 0b0000000000), // 2^-3 + F16VAL(0b00000, 0b0000000000), // 0 + F16VAL(0b00000, 0b0000000000), // 0 + F16VAL(0b01100, 0b0000000000), // 2^-3 }, // Smallest value that doesn't underflow to zero, due to mantissa rounding // up and incrementing the exponent out of the denormal range. { - FPVAL(01110000, 1111111111, 1000000000000), // 6.10203e-05 - FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 - FPVAL(00000000, 0000000000, 0000000000000), // 0.0 - FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05 + F16VAL(0b01100, 0b1111111100), // 1020 * 2^-3 + F16VAL(0b01101, 0b0000000000), // 2^-2 + F16VAL(0b00000, 0b0000000000), // 0 + F16VAL(0b01101, 0b0000000000), // 2^-2 + }, +}; + +// We want to test bfloat16 (a nop), cases that reduce just the +// mantissa/exponent, and a case that reduces both. +// +// We don't have a lot of tests here, relying instead on the coverage we have of +// f32 and f64. +static const int bf16_exponent_sizes[] = {8, 5, 5, 8}; +static const int bf16_mantissa_sizes[] = {7, 5, 7, 5}; + +// The BF16VAL macro lets us write out the binary representation of a bf16 in a +// more readable manner, separating out the exponent and mantissa. +#define BF16VAL(EXPONENT, MANTISSA) ((EXPONENT << 7) + (MANTISSA)) + +static const uint16 bf16_test_values[][4] = { + // True zero. + { + BF16VAL(0b00000, 0b0000000000), // 0.0 + BF16VAL(0b00000, 0b0000000000), // 0.0 + BF16VAL(0b00000, 0b0000000000), // 0.0 + BF16VAL(0b00000, 0b0000000000), // 0.0 + }, + // One. + { + BF16VAL(0b01111111, 0b0000000), // 1.0 + BF16VAL(0b01111111, 0b0000000), // 1.0 + BF16VAL(0b01111111, 0b0000000), // 1.0 + BF16VAL(0b01111111, 0b0000000), // 1.0 + }, + // Largest exponent that underflows to zero. + { + BF16VAL(0b01110000, 0b0000000), // 3.05176e-05 + BF16VAL(0b00000000, 0b0000000), // 0.0 + BF16VAL(0b00000000, 0b0000000), // 0.0 + BF16VAL(0b01110000, 0b0000000) // 3.05176e-05 + }, + // Smallest value that doesn't underflow to zero, due to mantissa rounding + // up and incrementing the exponent out of the denormal range. + { + BF16VAL(0b01110000, 0b1111110), // 6.05583e-05 + BF16VAL(0b01110001, 0b0000000), // 6.10352e-05 + BF16VAL(0b00000000, 0b0000000), // 0.0 + BF16VAL(0b01110001, 0b0000000), // 6.10352e-05 + }, +}; + +// We want to test IEEE-f32 (a no-op), IEEE-f16, and exponent-reduction-only and +// mantissa-reduction-only variants of IEEE-f16. +static const int f32_exponent_sizes[] = {8, 5, 5, 8}; +static const int f32_mantissa_sizes[] = {23, 10, 23, 10}; + +// The F32VAL macro allows us to write out the binary representation of the +// input and expected values in a more readable manner. 
The mantissa bits +// are separated into the "high" bits (retained with reduction to IEEE-f16) +// and the "low" bits (truncated with reduction to IEEE-f16). +#define F32VAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \ + ((EXPONENT << 23) + (HIGH_MANTISSA << 13) + (LOW_MANTISSA)) + +static const uint32 f32_test_values[][4] = { + // True zero. + { + F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0 + F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0 + F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0 + F32VAL(0b00000000, 0b0000000000, 0b0000000000000) // 0.0 + }, + // Largest exponent that underflows to zero. + { + F32VAL(0b01110000, 0b0000000000, 0b0000000000000), // 3.05176e-05 + F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0 + F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0 + F32VAL(0b01110000, 0b0000000000, 0b0000000000000) // 3.05176e-05 + }, + // Largest value that rounds to a denormal and thus clamps to zero. + { + F32VAL(0b01110000, 0b1111111111, 0b0111111111111), // 6.10203e-05 + F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0 + F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0 + F32VAL(0b01110000, 0b1111111111, 0b0000000000000) // 6.10054e-05 + }, + // Smallest value that doesn't underflow to zero, due to mantissa rounding + // up and incrementing the exponent out of the denormal range. + { + F32VAL(0b01110000, 0b1111111111, 0b1000000000000), // 6.10203e-05 + F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05 + F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0 + F32VAL(0b01110001, 0b0000000000, 0b0000000000000) // 6.10352e-05 }, // Smallest value that doesn't underflow to zero even without mantissa // rounding. { - FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 - FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 - FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 - FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05 + F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05 + F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05 + F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05 + F32VAL(0b01110001, 0b0000000000, 0b0000000000000) // 6.10352e-05 }, - // One (to make sure bias-handling is done correctly. + // One (to make sure bias-handling is done correctly). { - FPVAL(01111111, 0000000000, 0000000000000), // 1.0 - FPVAL(01111111, 0000000000, 0000000000000), // 1.0 - FPVAL(01111111, 0000000000, 0000000000000), // 1.0 - FPVAL(01111111, 0000000000, 0000000000000) // 1.0 + F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0 + F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0 + F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0 + F32VAL(0b01111111, 0b0000000000, 0b0000000000000) // 1.0 }, // Values in a space where ties round down due to ties-to-even: // Value with highest mantissa that rounds down. { - FPVAL(01111111, 0000000000, 1000000000000), // 1.00049 - FPVAL(01111111, 0000000000, 0000000000000), // 1.0 - FPVAL(01111111, 0000000000, 1000000000000), // 1.00049 - FPVAL(01111111, 0000000000, 0000000000000) // 1.0 + F32VAL(0b01111111, 0b0000000000, 0b1000000000000), // 1.00049 + F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0 + F32VAL(0b01111111, 0b0000000000, 0b1000000000000), // 1.00049 + F32VAL(0b01111111, 0b0000000000, 0b0000000000000) // 1.0 }, // Value with lowest mantissa that rounds up. 
{ - FPVAL(01111111, 0000000000, 1000000000001), // 1.00049 - FPVAL(01111111, 0000000001, 0000000000000), // 1.00098 - FPVAL(01111111, 0000000000, 1000000000001), // 1.00049 - FPVAL(01111111, 0000000001, 0000000000000) // 1.00098 + F32VAL(0b01111111, 0b0000000000, 0b1000000000001), // 1.00049 + F32VAL(0b01111111, 0b0000000001, 0b0000000000000), // 1.00098 + F32VAL(0b01111111, 0b0000000000, 0b1000000000001), // 1.00049 + F32VAL(0b01111111, 0b0000000001, 0b0000000000000) // 1.00098 }, // Values in a space where ties round up due to ties-to-even: // Value with highest mantissa that rounds down. { - FPVAL(01111111, 0000000001, 0111111111111), // 1.00146 - FPVAL(01111111, 0000000001, 0000000000000), // 1.00098 - FPVAL(01111111, 0000000001, 0111111111111), // 1.00146 - FPVAL(01111111, 0000000001, 0000000000000) // 1.00098 + F32VAL(0b01111111, 0b0000000001, 0b0111111111111), // 1.00146 + F32VAL(0b01111111, 0b0000000001, 0b0000000000000), // 1.00098 + F32VAL(0b01111111, 0b0000000001, 0b0111111111111), // 1.00146 + F32VAL(0b01111111, 0b0000000001, 0b0000000000000) // 1.00098 }, // Value with a mantissa that rounds up. { - FPVAL(01111111, 0000000001, 1000000000000), // 1.00146 - FPVAL(01111111, 0000000010, 0000000000000), // 1.00195 - FPVAL(01111111, 0000000001, 1000000000000), // 1.00146 - FPVAL(01111111, 0000000010, 0000000000000) // 1.00195 + F32VAL(0b01111111, 0b0000000001, 0b1000000000000), // 1.00146 + F32VAL(0b01111111, 0b0000000010, 0b0000000000000), // 1.00195 + F32VAL(0b01111111, 0b0000000001, 0b1000000000000), // 1.00146 + F32VAL(0b01111111, 0b0000000010, 0b0000000000000) // 1.00195 }, // Largest value that does not overflow to infinity. { - FPVAL(10001110, 1111111111, 0111111111111), // 65520.0 - FPVAL(10001110, 1111111111, 0000000000000), // 65504.0 - FPVAL(10001110, 1111111111, 0111111111111), // 65520.0 - FPVAL(10001110, 1111111111, 0000000000000) // 65504.0 + F32VAL(0b10001110, 0b1111111111, 0b0111111111111), // 65520.0 + F32VAL(0b10001110, 0b1111111111, 0b0000000000000), // 65504.0 + F32VAL(0b10001110, 0b1111111111, 0b0111111111111), // 65520.0 + F32VAL(0b10001110, 0b1111111111, 0b0000000000000) // 65504.0 }, // Smallest value that overflows to infinity due to mantissa rounding up. { - FPVAL(10001110, 1111111111, 1000000000000), // 65520.0 - FPVAL(11111111, 0000000000, 0000000000000), // Inf - FPVAL(10001110, 1111111111, 1000000000000), // 65520.0 - FPVAL(10001111, 0000000000, 0000000000000) // 65536.0 + F32VAL(0b10001110, 0b1111111111, 0b1000000000000), // 65520.0 + F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf + F32VAL(0b10001110, 0b1111111111, 0b1000000000000), // 65520.0 + F32VAL(0b10001111, 0b0000000000, 0b0000000000000) // 65536.0 }, // Smallest value that overflows to infinity, without mantissa rounding. { - FPVAL(10001111, 0000000000, 0000000000000), // 65536.0 - FPVAL(11111111, 0000000000, 0000000000000), // Inf - FPVAL(11111111, 0000000000, 0000000000000), // Inf - FPVAL(10001111, 0000000000, 0000000000000) // 65536.0 + F32VAL(0b10001111, 0b0000000000, 0b0000000000000), // 65536.0 + F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf + F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf + F32VAL(0b10001111, 0b0000000000, 0b0000000000000) // 65536.0 }, // Smallest value that overflows to infinity due to mantissa rounding up, // even when exponent bits aren't reduced. 
{ - FPVAL(11111110, 1111111111, 1000000000000), // 3.40199e+38 - FPVAL(11111111, 0000000000, 0000000000000), // Inf - FPVAL(11111111, 0000000000, 0000000000000), // Inf - FPVAL(11111111, 0000000000, 0000000000000) // Inf + F32VAL(0b11111110, 0b1111111111, 0b1000000000000), // 3.40199e+38 + F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf + F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf + F32VAL(0b11111111, 0b0000000000, 0b0000000000000) // Inf }, // True infinity. { - FPVAL(11111111, 0000000000, 0000000000000), // Inf - FPVAL(11111111, 0000000000, 0000000000000), // Inf - FPVAL(11111111, 0000000000, 0000000000000), // Inf - FPVAL(11111111, 0000000000, 0000000000000) // Inf + F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf + F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf + F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf + F32VAL(0b11111111, 0b0000000000, 0b0000000000000) // Inf }, // NAN with a 1 in the preserved bits. { - FPVAL(11111111, 1000000000, 0000000000000), // NaN - FPVAL(11111111, 1000000000, 0000000000000), // NaN - FPVAL(11111111, 1000000000, 0000000000000), // NaN - FPVAL(11111111, 1000000000, 0000000000000) // NaN + F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN + F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN + F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN + F32VAL(0b11111111, 0b1000000000, 0b0000000000000) // NaN }, // NAN with a 1 in the truncated bits. { - FPVAL(11111111, 0000000000, 0000000000001), // NaN - FPVAL(11111111, 0000000000, 0000000000001), // NaN - FPVAL(11111111, 0000000000, 0000000000001), // NaN - FPVAL(11111111, 0000000000, 0000000000001) // NaN + F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN + F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN + F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN + F32VAL(0b11111111, 0b0000000000, 0b0000000000001) // NaN }, // NAN with all ones, causing rounding overflow. { - FPVAL(11111111, 1111111111, 1111111111111), // NaN - FPVAL(11111111, 1111111111, 1111111111111), // NaN - FPVAL(11111111, 1111111111, 1111111111111), // NaN - FPVAL(11111111, 1111111111, 1111111111111) // NaN + F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN + F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN + F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN + F32VAL(0b11111111, 0b1111111111, 0b1111111111111) // NaN }}; -XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { - int index = GetParam(); - int exponent_bits = exponent_sizes[index]; - int mantissa_bits = mantissa_sizes[index]; +// F64VAL is like F32VAL but for doubles. +// +// Here the "high" mantissa bits are those retained with reduction to IEEE-f32 +// (the first 23 bits), and the "low" bits are those truncated with reduction to +// IEEE-f32 (the remaining 29 bits). +#define F64VAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \ + ((uint64{EXPONENT} << 52) + (uint64{HIGH_MANTISSA} << 29) + \ + uint64{LOW_MANTISSA}) - std::vector input_values; - std::vector expected_values; +// We want to test IEEE-f64 (a no-op), IEEE-f32, and exponent-reduction-only and +// mantissa-reduction-only variants of IEEE-f32. +static const int f64_exponent_sizes[] = {11, 8, 8, 11}; +static const int f64_mantissa_sizes[] = {52, 23, 52, 23}; - const uint32_t sign_bit = 1u << 31; +static const uint64 f64_test_values[][4] = { + // True zero. 
+ { + F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0 + F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0 + F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0 + F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0 + }, + // Largest exponent that underflows to zero, namely -127 (encoded as + // -127 + 1023). + { + F64VAL(0b01110000000, 0x000000, 0x00000000), // 5.8774717541114375e-39 + F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0 + F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0 + F64VAL(0b01110000000, 0x000000, 0x00000000), // 5.8774717541114375e-39 + }, + // Largest value that rounds to a denormal and thus clamps to zero. + { + F64VAL(0b01110000000, 0x7FFFFF, 0x0FFFFFFF), // 1.1754943157898258e-38 + F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0 + F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0 + F64VAL(0b01110000000, 0x7FFFFF, 0x00000000), // 1.1754942807573643e-38 + }, + // Smallest value that doesn't underflow to zero, due to mantissa rounding + // up and incrementing the exponent out of the denormal range. + { + F64VAL(0b01110000000, 0x7FFFFF, 0x10000000), // 1.1754943157898259e-38 + F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38 + F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0 + F64VAL(0b01110000001, 0x000000, 0x00000000) // 1.1754943508222875e-38 + }, + // Smallest value that doesn't underflow to zero even without mantissa + // rounding. + { + F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38 + F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38 + F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38 + F64VAL(0b01110000001, 0x000000, 0x00000000) // 1.1754943508222875e-38 + }, + // One (to make sure bias-handling is done correctly). + { + F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0 + F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0 + F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0 + F64VAL(0b01111111111, 0x000000, 0x00000000) // 1.0 + }, + // Values in a space where ties round down due to ties-to-even: + // Value with highest mantissa that rounds down. + { + F64VAL(0b01111111111, 0x000000, 0x10000000), // 1.0000000596046448 + F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0 + F64VAL(0b01111111111, 0x000000, 0x10000000), // 1.0000000596046448 + F64VAL(0b01111111111, 0x000000, 0x00000000) // 1.0 + }, + // Value with lowest mantissa that rounds up. + { + F64VAL(0b01111111111, 0x000000, 0x10000001), // 1.000000059604645 + F64VAL(0b01111111111, 0x000001, 0x00000000), // 1.0000001192092896 + F64VAL(0b01111111111, 0x000000, 0x10000001), // 1.000000059604645 + F64VAL(0b01111111111, 0x000001, 0x00000000) // 1.0000001192092896 + }, + // Values in a space where ties round up due to ties-to-even: + // Value with highest mantissa that rounds down. + { + F64VAL(0b01111111111, 0x000001, 0x0fffffff), // 1.0000001788139341 + F64VAL(0b01111111111, 0x000001, 0x00000000), // 1.0000001192092896 + F64VAL(0b01111111111, 0x000001, 0x0fffffff), // 1.0000001788139341 + F64VAL(0b01111111111, 0x000001, 0x00000000) // 1.0000001192092896 + }, + // Value with a mantissa that rounds up. + { + F64VAL(0b01111111111, 0x000001, 0x10000000), // 1.0000001788139343 + F64VAL(0b01111111111, 0x000002, 0x00000000), // 1.0000002384185791 + F64VAL(0b01111111111, 0x000001, 0x10000000), // 1.0000001788139343 + F64VAL(0b01111111111, 0x000002, 0x00000000), // 1.0000002384185791 + }, + // Largest value that does not overflow to infinity. 
+ { + F64VAL(0b10001111110, 0x7fffff, 0x0fffffff), // 3.4028235677973362e+38 + F64VAL(0b10001111110, 0x7fffff, 0x00000000), // 3.4028234663852886e+38 + F64VAL(0b10001111110, 0x7fffff, 0x0fffffff), // 3.4028235677973362e+38 + F64VAL(0b10001111110, 0x7fffff, 0x00000000), // 3.4028234663852886e+38 + }, + // Smallest value that overflows to infinity due to mantissa rounding up. + { + F64VAL(0b10001111110, 0x7fffff, 0x10000000), // 3.4028235677973366e+38 + F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf + F64VAL(0b10001111110, 0x7fffff, 0x10000000), // 3.4028235677973366e+38 + F64VAL(0b10001111111, 0x000000, 0x00000000) // 3.4028236692093846e+38 + }, + // Smallest value that overflows to infinity, without mantissa rounding. + { + F64VAL(0b10001111111, 0x000000, 0x00000000), // 3.4028236692093846e+38 + F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf + F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf + F64VAL(0b10001111111, 0x000000, 0x00000000) // 3.4028236692093846e+38 + }, + // Smallest value that overflows to infinity due to mantissa rounding up, + // even when exponent bits aren't reduced. + { + F64VAL(0b11111111110, 0x7fffff, 0x10000000), // 1.7976930812868855e+308 + F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf + F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf + F64VAL(0b11111111111, 0x000000, 0x00000000) // Inf + }, + // True infinity. + { + F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf + F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf + F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf + F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf + }, + // NAN with a 1 in the preserved bits. + { + F64VAL(0b11111111111, 0x800000, 0x00000000), // -0 + F64VAL(0b11111111111, 0x800000, 0x00000000), // -0 + F64VAL(0b11111111111, 0x800000, 0x00000000), // -0 + F64VAL(0b11111111111, 0x800000, 0x00000000), // -0 + }, + // NAN with a 1 in the truncated bits. + { + F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN + F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN + F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN + F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN + }, + // NAN with all ones, causing rounding overflow. 
+    {
+        F64VAL(0b11111111111, 0x7fffff, 0x1fffffff),  // NaN
+        F64VAL(0b11111111111, 0x7fffff, 0x1fffffff),  // NaN
+        F64VAL(0b11111111111, 0x7fffff, 0x1fffffff),  // NaN
+        F64VAL(0b11111111111, 0x7fffff, 0x1fffffff),  // NaN
+    },
+};
+
+class ReducedPrecisionAccuracyTest : public ClientLibraryTestBase,
+                                     public ::testing::WithParamInterface<int> {
+ protected:
+  template <typename Fp, typename Uint, int kNumInputs, int kNumTestcases>
+  void DoIt(int exponent_bits, int mantissa_bits,
+            const Uint (&test_values)[kNumInputs][kNumTestcases],
+            int operation_index);
+};
+
+XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionHalf) {
+  int operation_index = GetParam();
+  DoIt<Eigen::half, uint16>(f16_exponent_sizes[operation_index],
+                            f16_mantissa_sizes[operation_index],
+                            f16_test_values, operation_index);
+}
+
+XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionBfloat16) {
+  int operation_index = GetParam();
+  DoIt<bfloat16, uint16>(bf16_exponent_sizes[operation_index],
+                         bf16_mantissa_sizes[operation_index], bf16_test_values,
+                         operation_index);
+}
+
+XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionFloat) {
+  int operation_index = GetParam();
+  DoIt<float, uint32>(f32_exponent_sizes[operation_index],
+                      f32_mantissa_sizes[operation_index], f32_test_values,
+                      operation_index);
+}
+
+XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionDouble) {
+  int operation_index = GetParam();
+  DoIt<double, uint64>(f64_exponent_sizes[operation_index],
+                       f64_mantissa_sizes[operation_index], f64_test_values,
+                       operation_index);
+}
+
+template <typename Fp, typename Uint, int kNumInputs, int kNumTestcases>
+void ReducedPrecisionAccuracyTest::DoIt(
+    int exponent_bits, int mantissa_bits,
+    const Uint (&test_values)[kNumInputs][kNumTestcases], int operation_index) {
+  SCOPED_TRACE(absl::StrFormat("operation_index %d", operation_index));
+  SCOPED_TRACE(absl::StrFormat("%d exponent bits, %d mantissa bits",
+                               exponent_bits, mantissa_bits));
+
+  std::vector<Fp> input_values;
+  std::vector<Fp> expected_values;
+
+  const Uint sign_bit = Uint{1} << (sizeof(Fp) * 8 - 1);
 
   for (const auto& test_value : test_values) {
     // Add positive values.
-    input_values.push_back(absl::bit_cast<float>(test_value[0]));
-    expected_values.push_back(absl::bit_cast<float>(test_value[index]));
+    input_values.push_back(absl::bit_cast<Fp>(test_value[0]));
+    expected_values.push_back(absl::bit_cast<Fp>(test_value[operation_index]));
     // Add negative values. We do this in the bitwise representation so as to
     // avoid problems with NaN handling.
-    input_values.push_back(absl::bit_cast<float>(test_value[0] ^ sign_bit));
+    input_values.push_back(absl::bit_cast<Fp>(test_value[0] ^ sign_bit));
     expected_values.push_back(
-        absl::bit_cast<float>(test_value[index] ^ sign_bit));
+        absl::bit_cast<Fp>(test_value[operation_index] ^ sign_bit));
   }
 
   // This is required for proper handling of NaN values.
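Reviewer note (not part of the test file): to make the bit-level encoding used by the tables above concrete, the small standalone check below decodes two of the f32 table entries. The F32VAL macro is copied from the test; everything else is local to this sketch.

#include <cassert>
#include <cstdint>
#include <cstring>

// Copied from the test above.
#define F32VAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
  ((EXPONENT << 23) + (HIGH_MANTISSA << 13) + (LOW_MANTISSA))

int main() {
  // Exponent 0b01111111 is the f32 bias (127), i.e. an unbiased exponent of 0,
  // so with an all-zero mantissa the encoded value is exactly 1.0f.
  uint32_t bits = F32VAL(0b01111111, 0b0000000000, 0b0000000000000);
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  assert(value == 1.0f);

  // Exponent 0b10001110 is 142 (unbiased 15); an all-ones "high" mantissa
  // gives (2 - 2^-10) * 2^15 = 65504, the largest finite IEEE-f16 value.
  bits = F32VAL(0b10001110, 0b1111111111, 0b0000000000000);
  std::memcpy(&value, &bits, sizeof(value));
  assert(value == 65504.0f);
  return 0;
}

The same reading applies to F16VAL, BF16VAL and F64VAL, with the shift amounts adjusted to each type's mantissa width.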
@@ -229,19 +519,18 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
 
   XlaBuilder builder(TestName());
 
-  Literal a_literal = LiteralUtil::CreateR1<float>({input_values});
+  Literal a_literal = LiteralUtil::CreateR1<Fp>({input_values});
   std::unique_ptr<GlobalData> a_data =
       client_->TransferToServer(a_literal).ConsumeValueOrDie();
   auto a = Parameter(&builder, 0, a_literal.shape(), "a");
 
   ReducePrecision(a, exponent_bits, mantissa_bits);
 
-  ComputeAndCompareR1<float>(&builder, expected_values, {a_data.get()});
+  ComputeAndCompareR1<Fp>(&builder, expected_values, {a_data.get()});
 }
 
-INSTANTIATE_TEST_CASE_P(ReducePrecisionAccuracyTest,
-                        ReducePrecisionAccuracyTest,
-                        ::testing::Values(0, 1, 2, 3), TestDataToString);
+INSTANTIATE_TEST_CASE_P(ReducedPrecisionAccuracyTest,
+                        ReducedPrecisionAccuracyTest, ::testing::Range(0, 4));
 
 // Tests to confirm that the compiler optimization functions add the expected
 // ReducePrecisionInsertion passes.
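Reviewer note (not part of the patch): as a usage illustration of the table-driven scheme, the sketch below mirrors what DoIt above does with the f64 table: column 0 of a row is bit-cast to the input, the column selected by operation_index to the expected output, and every row is also emitted with the sign bit flipped so that negative values and NaNs are covered. The function name and layout here are assumptions local to this sketch.

#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

// Build (input, expected) pairs from a table laid out like f64_test_values
// above: column 0 of each row is the input bit pattern, and the column
// selected by `operation_index` is the expected output bit pattern.
std::vector<std::pair<double, double>> BuildF64Cases(
    const uint64_t (*table)[4], int num_rows, int operation_index) {
  std::vector<std::pair<double, double>> cases;
  const uint64_t sign_bit = uint64_t{1} << 63;
  const uint64_t flips[] = {0, sign_bit};
  for (int i = 0; i < num_rows; ++i) {
    // Emit each row as-is and with the sign bit flipped; flipping in the
    // integer domain avoids NaN-handling surprises from negating a double.
    for (uint64_t flip : flips) {
      const uint64_t input_bits = table[i][0] ^ flip;
      const uint64_t expected_bits = table[i][operation_index] ^ flip;
      double input, expected;
      std::memcpy(&input, &input_bits, sizeof(input));
      std::memcpy(&expected, &expected_bits, sizeof(expected));
      cases.emplace_back(input, expected);
    }
  }
  return cases;
}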