[XLA] Support the reduce-precision HLO for all floating-point types.
PiperOrigin-RevId: 258823530
This commit is contained in:
parent
c68ed58053
commit
4bd8619abf
@ -63,28 +63,44 @@ int64 GlobalRandomValue() {
|
||||
return rng();
|
||||
}
|
||||
|
||||
llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
||||
int64 mantissa_bits,
|
||||
llvm::IRBuilder<>* b) {
|
||||
StatusOr<llvm::Value*> EmitReducePrecisionIR(PrimitiveType src_ty,
|
||||
llvm::Value* x,
|
||||
int64 dest_exponent_bits,
|
||||
int64 dest_mantissa_bits,
|
||||
llvm::IRBuilder<>* b) {
|
||||
using llvm::APInt;
|
||||
|
||||
if (!primitive_util::IsFloatingPointType(src_ty)) {
|
||||
return Unimplemented(
|
||||
"ReducePrecision cannot accept non-floating-point type %s.",
|
||||
PrimitiveType_Name(src_ty));
|
||||
}
|
||||
|
||||
// Integer and float types for casting and constant generation.
|
||||
llvm::Type* float_type = x->getType();
|
||||
llvm::IntegerType* int_type = b->getInt32Ty();
|
||||
int64 nbits = float_type->getPrimitiveSizeInBits();
|
||||
llvm::IntegerType* int_type = b->getIntNTy(nbits);
|
||||
|
||||
// SignificandWidth includes the implicit extra bit.
|
||||
int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
|
||||
int src_exponent_bits = nbits - 1 - src_mantissa_bits;
|
||||
|
||||
// Cast the input value to an integer for bitwise manipulation.
|
||||
llvm::Value* x_as_int = b->CreateBitCast(x, int_type);
|
||||
|
||||
if (mantissa_bits < 23) {
|
||||
if (dest_mantissa_bits < src_mantissa_bits) {
|
||||
// Last remaining mantissa bit.
|
||||
const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
|
||||
APInt last_mantissa_bit_mask(nbits, 1);
|
||||
last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;
|
||||
|
||||
// Compute rounding bias for round-to-nearest with ties to even. This is
|
||||
// equal to a base value of 0111... plus one bit if the last remaining
|
||||
// mantissa bit is 1.
|
||||
const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
|
||||
APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
|
||||
llvm::Value* x_last_mantissa_bit = b->CreateLShr(
|
||||
b->CreateAnd(x_as_int,
|
||||
llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
|
||||
(23 - mantissa_bits));
|
||||
(src_mantissa_bits - dest_mantissa_bits));
|
||||
llvm::Value* x_rounding_bias =
|
||||
b->CreateAdd(x_last_mantissa_bit,
|
||||
llvm::ConstantInt::get(int_type, base_rounding_bias));
|
||||
@ -93,16 +109,19 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
||||
// where adding the rounding bias overflows into the exponent bits is
|
||||
// correct; the non-masked mantissa bits will all be zero, and the
|
||||
// exponent will be incremented by one.
|
||||
const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
|
||||
APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
|
||||
x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
|
||||
x_as_int = b->CreateAnd(x_as_int,
|
||||
llvm::ConstantInt::get(int_type, truncation_mask));
|
||||
}
|
||||
|
||||
if (exponent_bits < 8) {
|
||||
// Masks for f32 values.
|
||||
const uint32_t f32_sign_bit_mask = 1u << 31;
|
||||
const uint32_t f32_exp_bits_mask = 0xffu << 23;
|
||||
if (dest_exponent_bits < src_exponent_bits) {
|
||||
APInt sign_bit_mask(nbits, 1);
|
||||
sign_bit_mask <<= nbits - 1;
|
||||
|
||||
APInt exp_bits_mask(nbits, 1);
|
||||
exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
|
||||
<< src_mantissa_bits;
|
||||
|
||||
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
|
||||
// significant bit -- is equal to 1.0f for all exponent sizes. Adding
|
||||
@ -116,28 +135,31 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
||||
// (2^7-1) - 2^(n-1)-1.
|
||||
//
|
||||
// Note that we have already checked that dest_exponent_bits >= 1.
|
||||
const uint32_t f32_exponent_bias = (1 << 7) - 1;
|
||||
const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
|
||||
const uint32_t reduced_max_exponent =
|
||||
f32_exponent_bias + reduced_exponent_bias;
|
||||
const uint32_t reduced_min_exponent =
|
||||
f32_exponent_bias - reduced_exponent_bias;
|
||||
APInt exponent_bias(nbits, 1);
|
||||
exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;
|
||||
|
||||
APInt reduced_exponent_bias(nbits, 1);
|
||||
reduced_exponent_bias =
|
||||
(reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;
|
||||
|
||||
APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
|
||||
APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;
|
||||
|
||||
// Do we overflow or underflow?
|
||||
llvm::Value* x_exponent = b->CreateAnd(
|
||||
x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
|
||||
llvm::Value* x_exponent =
|
||||
b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
|
||||
llvm::Value* x_overflows = b->CreateICmpUGT(
|
||||
x_exponent,
|
||||
llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
|
||||
x_exponent, llvm::ConstantInt::get(
|
||||
int_type, reduced_max_exponent << src_mantissa_bits));
|
||||
llvm::Value* x_underflows = b->CreateICmpULE(
|
||||
x_exponent,
|
||||
llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
|
||||
x_exponent, llvm::ConstantInt::get(
|
||||
int_type, reduced_min_exponent << src_mantissa_bits));
|
||||
|
||||
// Compute appropriately-signed values of zero and infinity.
|
||||
llvm::Value* x_signed_zero = b->CreateAnd(
|
||||
x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
|
||||
llvm::Value* x_signed_zero =
|
||||
b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
|
||||
llvm::Value* x_signed_inf = b->CreateOr(
|
||||
x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
|
||||
x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));
|
||||
|
||||
// Force to zero or infinity if overflow or underflow. (Note that this
|
||||
// truncates all denormal values to zero, rather than rounding them.)
|
||||
@ -161,7 +183,7 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
||||
if (!b->getFastMathFlags().noNaNs()) {
|
||||
llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);
|
||||
|
||||
if (mantissa_bits > 0) {
|
||||
if (dest_mantissa_bits > 0) {
|
||||
result = b->CreateSelect(x_is_nan, x, result);
|
||||
} else {
|
||||
result = b->CreateSelect(
|
||||
@ -171,11 +193,14 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
||||
return result;
|
||||
}
|
||||
|
||||
llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, llvm::IRBuilder<>* b) {
|
||||
auto reduced_precision = EmitReducePrecisionFloat(
|
||||
f32_value,
|
||||
/*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
|
||||
/*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b);
|
||||
StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
|
||||
llvm::IRBuilder<>* b) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto reduced_precision,
|
||||
EmitReducePrecisionIR(
|
||||
/*src_ty=*/F32, f32_value,
|
||||
/*dest_exponent_bits=*/primitive_util::kBFloat16ExponentBits,
|
||||
/*dest_mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b));
|
||||
auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
|
||||
auto shifted = b->CreateLShr(as_int32, 16);
|
||||
auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
|
||||
@ -1099,11 +1124,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
|
||||
const HloInstruction* hlo, llvm::Value* x) {
|
||||
if (hlo->operand(0)->shape().element_type() != F32) {
|
||||
return Unimplemented("reduce-precision only implemented for F32");
|
||||
}
|
||||
return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
|
||||
/*mantissa_bits=*/hlo->mantissa_bits(), b_);
|
||||
return EmitReducePrecisionIR(
|
||||
/*src_ty=*/hlo->operand(0)->shape().element_type(), x,
|
||||
/*dest_exponent_bits=*/hlo->exponent_bits(),
|
||||
/*dest_mantissa_bits=*/hlo->mantissa_bits(), b_);
|
||||
}
|
||||
|
||||
static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
|
||||
@ -1990,7 +2014,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
llvm::Value* float_val =
|
||||
b_->CreateUIToFP(elem_index_linear, float_ir_type);
|
||||
if (component_element_type == BF16) {
|
||||
iota_result = EmitF32ToBF16(float_val, b_);
|
||||
TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
|
||||
} else {
|
||||
iota_result = float_val;
|
||||
}
|
||||
|
@ -67,6 +67,30 @@ T ToArithmeticSafeType(T t) {
|
||||
return std::move(t);
|
||||
}
|
||||
|
||||
// UintWithSize<N> gets an unsigned integer with the given size in bytes.
|
||||
template <size_t Bytes>
|
||||
struct UintWithSize {};
|
||||
|
||||
template <>
|
||||
struct UintWithSize<1> {
|
||||
using type = uint8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UintWithSize<2> {
|
||||
using type = uint16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UintWithSize<4> {
|
||||
using type = uint32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UintWithSize<8> {
|
||||
using type = uint64;
|
||||
};
|
||||
|
||||
// Templated DfsHloVisitor for use by HloEvaluator.
|
||||
//
|
||||
// Typically ReturnT here indicates the resulting literal type of each evaluated
|
||||
@ -2389,48 +2413,57 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return HandleCos<ElementwiseT>(cos);
|
||||
}
|
||||
|
||||
template <typename NativeT, typename std::enable_if<std::is_same<
|
||||
float, NativeT>::value>::type* = nullptr>
|
||||
template <typename NativeT,
|
||||
typename std::enable_if<
|
||||
std::is_same<NativeT, float>::value ||
|
||||
std::is_same<NativeT, double>::value>::type* = nullptr>
|
||||
Status HandleReducePrecision(HloInstruction* reduce_precision) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[reduce_precision],
|
||||
ElementWiseUnaryOp(reduce_precision, [reduce_precision](
|
||||
ElementwiseT elem) {
|
||||
uint32_t value_as_int = absl::bit_cast<uint32_t>(elem);
|
||||
const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
|
||||
const uint32_t exponent_bits = reduce_precision->exponent_bits();
|
||||
ElementWiseUnaryOp(reduce_precision, [&](ElementwiseT elem) {
|
||||
const uint32 src_mantissa_bits =
|
||||
std::numeric_limits<NativeT>::digits - 1;
|
||||
const uint32 src_exponent_bits =
|
||||
8 * sizeof(NativeT) - src_mantissa_bits - 1;
|
||||
const uint32 dest_mantissa_bits = reduce_precision->mantissa_bits();
|
||||
const uint32 dest_exponent_bits = reduce_precision->exponent_bits();
|
||||
|
||||
using Uint = typename UintWithSize<sizeof(NativeT)>::type;
|
||||
Uint value_as_int = absl::bit_cast<Uint>(elem);
|
||||
|
||||
// Code is based on the CPU/GPU implementation in LLVM-emitting code.
|
||||
//
|
||||
// Bits in float type:
|
||||
// Bits in float32 type:
|
||||
// mantissa : bits [0:22]
|
||||
// exponent : bits [23:30]
|
||||
// sign : bits [31]
|
||||
if (mantissa_bits < 23) {
|
||||
const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
|
||||
if (dest_mantissa_bits < src_mantissa_bits) {
|
||||
const Uint last_mantissa_bit_mask =
|
||||
Uint{1} << (src_mantissa_bits - dest_mantissa_bits);
|
||||
|
||||
// Compute rounding bias for round-to-nearest with ties to even.
|
||||
// This is equal to a base value of 0111... plus one bit if the last
|
||||
// remaining mantissa bit is 1.
|
||||
const uint32_t base_rounding_bias =
|
||||
(last_mantissa_bit_mask >> 1) - 1;
|
||||
const uint32_t x_last_mantissa_bit =
|
||||
(value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
|
||||
const uint32_t x_rounding_bias =
|
||||
const Uint base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
|
||||
const Uint x_last_mantissa_bit =
|
||||
(value_as_int & last_mantissa_bit_mask) >>
|
||||
(src_mantissa_bits - dest_mantissa_bits);
|
||||
const Uint x_rounding_bias =
|
||||
x_last_mantissa_bit + base_rounding_bias;
|
||||
|
||||
// Add rounding bias, and mask out truncated bits. Note that the
|
||||
// case where adding the rounding bias overflows into the exponent
|
||||
// bits is correct; the non-masked mantissa bits will all be zero,
|
||||
// and the exponent will be incremented by one.
|
||||
const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
|
||||
const Uint truncation_mask = ~(last_mantissa_bit_mask - 1);
|
||||
value_as_int = value_as_int + x_rounding_bias;
|
||||
value_as_int = value_as_int & truncation_mask;
|
||||
}
|
||||
if (exponent_bits < 8) {
|
||||
if (dest_exponent_bits < src_exponent_bits) {
|
||||
// Masks for the sign and exponent bits of NativeT.
|
||||
const uint32_t f32_sign_bit_mask = 1u << 31;
|
||||
const uint32_t f32_exp_bits_mask = 0xffu << 23;
|
||||
const Uint sign_bit_mask = Uint{1} << 8 * sizeof(NativeT) - 1;
|
||||
const Uint exp_bits_mask = (Uint{1 << src_exponent_bits} - 1)
|
||||
<< src_mantissa_bits;
|
||||
|
||||
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the
|
||||
// most-significant bit -- is equal to 1.0f for all exponent sizes.
|
||||
@ -2444,23 +2477,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
// is (2^7-1) - 2^(n-1)-1.
|
||||
//
|
||||
// Note that we have already checked that dest_exponent_bits >= 1.
|
||||
const uint32_t f32_exponent_bias = (1 << 7) - 1;
|
||||
const uint32_t reduced_exponent_bias =
|
||||
(1 << (exponent_bits - 1)) - 1;
|
||||
const uint32_t reduced_max_exponent =
|
||||
f32_exponent_bias + reduced_exponent_bias;
|
||||
const uint32_t reduced_min_exponent =
|
||||
f32_exponent_bias - reduced_exponent_bias;
|
||||
const Uint exponent_bias = (Uint{1} << (src_exponent_bits - 1)) - 1;
|
||||
const Uint reduced_exponent_bias =
|
||||
(1 << (dest_exponent_bits - 1)) - 1;
|
||||
const Uint reduced_max_exponent =
|
||||
exponent_bias + reduced_exponent_bias;
|
||||
const Uint reduced_min_exponent =
|
||||
exponent_bias - reduced_exponent_bias;
|
||||
|
||||
// Do we overflow or underflow?
|
||||
const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
|
||||
const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
|
||||
const Uint x_exponent = value_as_int & exp_bits_mask;
|
||||
const bool x_overflows =
|
||||
x_exponent > (reduced_max_exponent << src_mantissa_bits);
|
||||
const bool x_underflows =
|
||||
x_exponent <= (reduced_min_exponent << 23);
|
||||
x_exponent <= (reduced_min_exponent << src_mantissa_bits);
|
||||
|
||||
// Compute appropriately-signed values of zero and infinity.
|
||||
const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
|
||||
const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;
|
||||
const Uint x_signed_zero = value_as_int & sign_bit_mask;
|
||||
const Uint x_signed_inf = x_signed_zero | exp_bits_mask;
|
||||
|
||||
// Force to zero or infinity if overflow or underflow. (Note that
|
||||
// this truncates all denormal values to zero, rather than rounding
|
||||
@ -2469,23 +2503,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
value_as_int = x_underflows ? x_signed_zero : value_as_int;
|
||||
}
|
||||
|
||||
float reduced_result = absl::bit_cast<float>(value_as_int);
|
||||
NativeT reduced_result = absl::bit_cast<NativeT>(value_as_int);
|
||||
if (std::isnan(elem)) {
|
||||
reduced_result = mantissa_bits > 0
|
||||
reduced_result = dest_mantissa_bits > 0
|
||||
? elem
|
||||
: std::numeric_limits<float>::infinity();
|
||||
: std::numeric_limits<NativeT>::infinity();
|
||||
}
|
||||
return reduced_result;
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename NativeT, typename std::enable_if<std::is_same<
|
||||
double, NativeT>::value>::type* = nullptr>
|
||||
Status HandleReducePrecision(HloInstruction* reduce_precision) {
|
||||
return InvalidArgument("Double is not supported for reduce precision");
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<std::is_integral<NativeT>::value ||
|
||||
|
@ -39,189 +39,479 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// Tests to confirm that the ReducePrecision operation produces the expected
|
||||
// numerical values.
|
||||
class ReducePrecisionAccuracyTest : public ClientLibraryTestBase,
|
||||
public ::testing::WithParamInterface<int> {
|
||||
};
|
||||
|
||||
// For reduction to IEEE-f16, we want to test the following cases, in both
|
||||
// positive and negative variants. (Note: IEEE-f16 is 5 exponent bits and 10
|
||||
// mantissa bits.)
|
||||
// Testcases in this file work as follows.
|
||||
//
|
||||
// Vectors of exponent and mantissa sizes to test. We want to test IEEE-f32 (a
|
||||
// no-op), IEEE-f16, and exponent-reduction-only and mantissa-reduction-only
|
||||
// variants of IEEE-f16.
|
||||
static const int exponent_sizes[] = {8, 5, 5, 8};
|
||||
static const int mantissa_sizes[] = {23, 10, 23, 10};
|
||||
// for ty in {f16, bf16, f32, f64}:
|
||||
// for (operation_index, (e, m)) in \
|
||||
// enumerate(zip(ty_exponent_sizes, ty_mantissa_sizes)):
|
||||
//
|
||||
// for testcase in ty_test_values:
|
||||
// let input = testcase[0]
|
||||
// let expected = testcase[operation_index]
|
||||
//
|
||||
// CHECK that XLA-reduce-precision(
|
||||
// input, /*exp_bits=*/e, /*mantissa_bits=*/m) == expected
|
||||
//
|
||||
// Put into words:
|
||||
//
|
||||
// - ty_{exponent,mantissa}_sizes tell us the different ways we will reduce the
|
||||
// precision of `ty`.
|
||||
//
|
||||
// - ty_test_values is a 2D array of testcases, each of which is
|
||||
// len(ty_exponent_sizes) elements long. The first element corresponds to
|
||||
// the input, and the j'th element corresponds to the expected output of
|
||||
// doing a reduce-precision with parameters ty_{exponent,mantissa}_sizes[j].
|
||||
//
|
||||
// You'll note above that testcase[0] is reused as both expected and input when
|
||||
// operation_index == 0. This implies that ty_{exponent,mantissa}_sizes[0] must
|
||||
// be equal to `ty`'s exponent/mantissa size, making the reduce-precision op
|
||||
// tested a nop.
|
||||
|
||||
string TestDataToString(const ::testing::TestParamInfo<int> data) {
|
||||
int i = data.param;
|
||||
return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i],
|
||||
"_mantissa_bits");
|
||||
}
|
||||
// We want to test IEEE-f16 (a nop), cases that reduce just the
|
||||
// mantissa/exponent, and a case that reduces both.
|
||||
//
|
||||
// We don't have a lot of tests here, relying instead on the coverage we have of
|
||||
// f32 and f64.
|
||||
//
|
||||
// Note: The hypothetical float(3,7) type we're "converting" to would have:
|
||||
// max exp = 2^(3-1) - 1 = 3
|
||||
// min exp = -max_exp + 1 = -2
|
||||
static const int f16_exponent_sizes[] = {5, 3, 3, 5};
|
||||
static const int f16_mantissa_sizes[] = {10, 7, 10, 7};
|
||||
|
||||
// The FPVAL macro allows us to write out the binary representation of the
|
||||
// input and expected values in a more readable manner. The mantissa bits
|
||||
// are separated into the "high" bits (retained with reduction to IEEE-f16)
|
||||
// and the "low" bits (truncated with reduction to IEEE-f16).
|
||||
#define FPVAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
|
||||
((0b##EXPONENT << 23) + (0b##HIGH_MANTISSA << 13) + (0b##LOW_MANTISSA))
|
||||
// The F16VAL macro lets us write out the binary representation of an f16 in a
|
||||
// more readable manner, separating out the exponent and mantissa.
|
||||
#define F16VAL(EXPONENT, MANTISSA) ((EXPONENT << 10) + (MANTISSA))
|
||||
|
||||
// Each element in the test-value array consists of four numbers. The first is
|
||||
// the input value and the following are the expected output values for the
|
||||
// various precision-reduction cases.
|
||||
static const uint32_t test_values[][4] = {
|
||||
static const uint16 f16_test_values[][4] = {
|
||||
// True zero.
|
||||
{
|
||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
||||
FPVAL(00000000, 0000000000, 0000000000000) // 0.0
|
||||
F16VAL(0b00000, 0b0000000000), // 0.0
|
||||
F16VAL(0b00000, 0b0000000000), // 0.0
|
||||
F16VAL(0b00000, 0b0000000000), // 0.0
|
||||
F16VAL(0b00000, 0b0000000000), // 0.0
|
||||
},
|
||||
// Largest exponent that underflows to zero.
|
||||
// One.
|
||||
{
|
||||
FPVAL(01110000, 0000000000, 0000000000000), // 3.05176e-05
|
||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
||||
FPVAL(01110000, 0000000000, 0000000000000) // 3.05176e-05
|
||||
F16VAL(0b01111, 0b0000000000), // 1.0
|
||||
F16VAL(0b01111, 0b0000000000), // 1.0
|
||||
F16VAL(0b01111, 0b0000000000), // 1.0
|
||||
F16VAL(0b01111, 0b0000000000), // 1.0
|
||||
},
|
||||
// Largest value that rounds to a denormal and thus clamps to zero.
|
||||
// Largest exponent that underflows to zero is -3, which is encoded as
|
||||
// -3 + 15 = 12
|
||||
{
|
||||
FPVAL(01110000, 1111111111, 0111111111111), // 6.10203e-05
|
||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
||||
FPVAL(01110000, 1111111111, 0000000000000) // 6.10054e-05
|
||||
F16VAL(0b01100, 0b0000000000), // 2^-3
|
||||
F16VAL(0b00000, 0b0000000000), // 0
|
||||
F16VAL(0b00000, 0b0000000000), // 0
|
||||
F16VAL(0b01100, 0b0000000000), // 2^-3
|
||||
},
|
||||
// Smallest value that doesn't underflow to zero, due to mantissa rounding
|
||||
// up and incrementing the exponent out of the denormal range.
|
||||
{
|
||||
FPVAL(01110000, 1111111111, 1000000000000), // 6.10203e-05
|
||||
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
|
||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
||||
FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05
|
||||
F16VAL(0b01100, 0b1111111100), // 1020 * 2^-3
|
||||
F16VAL(0b01101, 0b0000000000), // 2^-2
|
||||
F16VAL(0b00000, 0b0000000000), // 0
|
||||
F16VAL(0b01101, 0b0000000000), // 2^-2
|
||||
},
|
||||
};
|
||||
|
||||
// We want to test bfloat16 (a nop), cases that reduce just the
|
||||
// mantissa/exponent, and a case that reduces both.
|
||||
//
|
||||
// We don't have a lot of tests here, relying instead on the coverage we have of
|
||||
// f32 and f64.
|
||||
static const int bf16_exponent_sizes[] = {8, 5, 5, 8};
|
||||
static const int bf16_mantissa_sizes[] = {7, 5, 7, 5};
|
||||
|
||||
// The BF16VAL macro lets us write out the binary representation of a bf16 in a
|
||||
// more readable manner, separating out the exponent and mantissa.
|
||||
#define BF16VAL(EXPONENT, MANTISSA) ((EXPONENT << 7) + (MANTISSA))
|
||||
|
||||
// bf16 test vectors, written as raw bit patterns with
// BF16VAL(8-bit exponent, 7-bit mantissa).
//
// Column j of each row pairs with bf16_exponent_sizes[j] and
// bf16_mantissa_sizes[j]: column 0 is the input (the {8, 7} entry is a nop on
// bf16), and columns 1..3 are the expected results of reducing that input to
// {5, 5}, {5, 7} and {8, 5} exponent/mantissa bits respectively.
static const uint16 bf16_test_values[][4] = {
    // True zero.
    {
        BF16VAL(0b00000000, 0b0000000),  // 0.0
        BF16VAL(0b00000000, 0b0000000),  // 0.0
        BF16VAL(0b00000000, 0b0000000),  // 0.0
        BF16VAL(0b00000000, 0b0000000),  // 0.0
    },
    // One.
    {
        BF16VAL(0b01111111, 0b0000000),  // 1.0
        BF16VAL(0b01111111, 0b0000000),  // 1.0
        BF16VAL(0b01111111, 0b0000000),  // 1.0
        BF16VAL(0b01111111, 0b0000000),  // 1.0
    },
    // Largest exponent that underflows to zero.
    {
        BF16VAL(0b01110000, 0b0000000),  // 3.05176e-05
        BF16VAL(0b00000000, 0b0000000),  // 0.0
        BF16VAL(0b00000000, 0b0000000),  // 0.0
        BF16VAL(0b01110000, 0b0000000),  // 3.05176e-05
    },
    // Smallest value that doesn't underflow to zero, due to mantissa rounding
    // up and incrementing the exponent out of the denormal range.
    {
        BF16VAL(0b01110000, 0b1111110),  // 6.05583e-05
        BF16VAL(0b01110001, 0b0000000),  // 6.10352e-05
        BF16VAL(0b00000000, 0b0000000),  // 0.0
        BF16VAL(0b01110001, 0b0000000),  // 6.10352e-05
    },
};
|
||||
|
||||
// We want to test IEEE-f32 (a no-op), IEEE-f16, and exponent-reduction-only and
|
||||
// mantissa-reduction-only variants of IEEE-f16.
|
||||
static const int f32_exponent_sizes[] = {8, 5, 5, 8};
|
||||
static const int f32_mantissa_sizes[] = {23, 10, 23, 10};
|
||||
|
||||
// The F32VAL macro allows us to write out the binary representation of the
|
||||
// input and expected values in a more readable manner. The mantissa bits
|
||||
// are separated into the "high" bits (retained with reduction to IEEE-f16)
|
||||
// and the "low" bits (truncated with reduction to IEEE-f16).
|
||||
#define F32VAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
|
||||
((EXPONENT << 23) + (HIGH_MANTISSA << 13) + (LOW_MANTISSA))
|
||||
|
||||
static const uint32 f32_test_values[][4] = {
|
||||
// True zero.
|
||||
{
|
||||
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||
F32VAL(0b00000000, 0b0000000000, 0b0000000000000) // 0.0
|
||||
},
|
||||
// Largest exponent that underflows to zero.
|
||||
{
|
||||
F32VAL(0b01110000, 0b0000000000, 0b0000000000000), // 3.05176e-05
|
||||
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||
F32VAL(0b01110000, 0b0000000000, 0b0000000000000) // 3.05176e-05
|
||||
},
|
||||
// Largest value that rounds to a denormal and thus clamps to zero.
|
||||
{
|
||||
F32VAL(0b01110000, 0b1111111111, 0b0111111111111), // 6.10203e-05
|
||||
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||
F32VAL(0b01110000, 0b1111111111, 0b0000000000000) // 6.10054e-05
|
||||
},
|
||||
// Smallest value that doesn't underflow to zero, due to mantissa rounding
|
||||
// up and incrementing the exponent out of the denormal range.
|
||||
{
|
||||
F32VAL(0b01110000, 0b1111111111, 0b1000000000000), // 6.10203e-05
|
||||
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
|
||||
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||
F32VAL(0b01110001, 0b0000000000, 0b0000000000000) // 6.10352e-05
|
||||
},
|
||||
// Smallest value that doesn't underflow to zero even without mantissa
|
||||
// rounding.
|
||||
{
|
||||
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
|
||||
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
|
||||
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
|
||||
FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05
|
||||
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
|
||||
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
|
||||
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
|
||||
F32VAL(0b01110001, 0b0000000000, 0b0000000000000) // 6.10352e-05
|
||||
},
|
||||
// One (to make sure bias-handling is done correctly.
|
||||
// One (to make sure bias-handling is done correctly).
|
||||
{
|
||||
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
|
||||
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
|
||||
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
|
||||
FPVAL(01111111, 0000000000, 0000000000000) // 1.0
|
||||
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
|
||||
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
|
||||
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
|
||||
F32VAL(0b01111111, 0b0000000000, 0b0000000000000) // 1.0
|
||||
},
|
||||
// Values in a space where ties round down due to ties-to-even:
|
||||
// Value with highest mantissa that rounds down.
|
||||
{
|
||||
FPVAL(01111111, 0000000000, 1000000000000), // 1.00049
|
||||
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
|
||||
FPVAL(01111111, 0000000000, 1000000000000), // 1.00049
|
||||
FPVAL(01111111, 0000000000, 0000000000000) // 1.0
|
||||
F32VAL(0b01111111, 0b0000000000, 0b1000000000000), // 1.00049
|
||||
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
|
||||
F32VAL(0b01111111, 0b0000000000, 0b1000000000000), // 1.00049
|
||||
F32VAL(0b01111111, 0b0000000000, 0b0000000000000) // 1.0
|
||||
},
|
||||
// Value with lowest mantissa that rounds up.
|
||||
{
|
||||
FPVAL(01111111, 0000000000, 1000000000001), // 1.00049
|
||||
FPVAL(01111111, 0000000001, 0000000000000), // 1.00098
|
||||
FPVAL(01111111, 0000000000, 1000000000001), // 1.00049
|
||||
FPVAL(01111111, 0000000001, 0000000000000) // 1.00098
|
||||
F32VAL(0b01111111, 0b0000000000, 0b1000000000001), // 1.00049
|
||||
F32VAL(0b01111111, 0b0000000001, 0b0000000000000), // 1.00098
|
||||
F32VAL(0b01111111, 0b0000000000, 0b1000000000001), // 1.00049
|
||||
F32VAL(0b01111111, 0b0000000001, 0b0000000000000) // 1.00098
|
||||
},
|
||||
// Values in a space where ties round up due to ties-to-even:
|
||||
// Value with highest mantissa that rounds down.
|
||||
{
|
||||
FPVAL(01111111, 0000000001, 0111111111111), // 1.00146
|
||||
FPVAL(01111111, 0000000001, 0000000000000), // 1.00098
|
||||
FPVAL(01111111, 0000000001, 0111111111111), // 1.00146
|
||||
FPVAL(01111111, 0000000001, 0000000000000) // 1.00098
|
||||
F32VAL(0b01111111, 0b0000000001, 0b0111111111111), // 1.00146
|
||||
F32VAL(0b01111111, 0b0000000001, 0b0000000000000), // 1.00098
|
||||
F32VAL(0b01111111, 0b0000000001, 0b0111111111111), // 1.00146
|
||||
F32VAL(0b01111111, 0b0000000001, 0b0000000000000) // 1.00098
|
||||
},
|
||||
// Value with a mantissa that rounds up.
|
||||
{
|
||||
FPVAL(01111111, 0000000001, 1000000000000), // 1.00146
|
||||
FPVAL(01111111, 0000000010, 0000000000000), // 1.00195
|
||||
FPVAL(01111111, 0000000001, 1000000000000), // 1.00146
|
||||
FPVAL(01111111, 0000000010, 0000000000000) // 1.00195
|
||||
F32VAL(0b01111111, 0b0000000001, 0b1000000000000), // 1.00146
|
||||
F32VAL(0b01111111, 0b0000000010, 0b0000000000000), // 1.00195
|
||||
F32VAL(0b01111111, 0b0000000001, 0b1000000000000), // 1.00146
|
||||
F32VAL(0b01111111, 0b0000000010, 0b0000000000000) // 1.00195
|
||||
},
|
||||
// Largest value that does not overflow to infinity.
|
||||
{
|
||||
FPVAL(10001110, 1111111111, 0111111111111), // 65520.0
|
||||
FPVAL(10001110, 1111111111, 0000000000000), // 65504.0
|
||||
FPVAL(10001110, 1111111111, 0111111111111), // 65520.0
|
||||
FPVAL(10001110, 1111111111, 0000000000000) // 65504.0
|
||||
F32VAL(0b10001110, 0b1111111111, 0b0111111111111), // 65520.0
|
||||
F32VAL(0b10001110, 0b1111111111, 0b0000000000000), // 65504.0
|
||||
F32VAL(0b10001110, 0b1111111111, 0b0111111111111), // 65520.0
|
||||
F32VAL(0b10001110, 0b1111111111, 0b0000000000000) // 65504.0
|
||||
},
|
||||
// Smallest value that overflows to infinity due to mantissa rounding up.
|
||||
{
|
||||
FPVAL(10001110, 1111111111, 1000000000000), // 65520.0
|
||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
||||
FPVAL(10001110, 1111111111, 1000000000000), // 65520.0
|
||||
FPVAL(10001111, 0000000000, 0000000000000) // 65536.0
|
||||
F32VAL(0b10001110, 0b1111111111, 0b1000000000000), // 65520.0
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||
F32VAL(0b10001110, 0b1111111111, 0b1000000000000), // 65520.0
|
||||
F32VAL(0b10001111, 0b0000000000, 0b0000000000000) // 65536.0
|
||||
},
|
||||
// Smallest value that overflows to infinity, without mantissa rounding.
|
||||
{
|
||||
FPVAL(10001111, 0000000000, 0000000000000), // 65536.0
|
||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
||||
FPVAL(10001111, 0000000000, 0000000000000) // 65536.0
|
||||
F32VAL(0b10001111, 0b0000000000, 0b0000000000000), // 65536.0
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||
F32VAL(0b10001111, 0b0000000000, 0b0000000000000) // 65536.0
|
||||
},
|
||||
// Smallest value that overflows to infinity due to mantissa rounding up,
|
||||
// even when exponent bits aren't reduced.
|
||||
{
|
||||
FPVAL(11111110, 1111111111, 1000000000000), // 3.40199e+38
|
||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
||||
FPVAL(11111111, 0000000000, 0000000000000) // Inf
|
||||
F32VAL(0b11111110, 0b1111111111, 0b1000000000000), // 3.40199e+38
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000) // Inf
|
||||
},
|
||||
// True infinity.
|
||||
{
|
||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
||||
FPVAL(11111111, 0000000000, 0000000000000) // Inf
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000000) // Inf
|
||||
},
|
||||
// NAN with a 1 in the preserved bits.
|
||||
{
|
||||
FPVAL(11111111, 1000000000, 0000000000000), // NaN
|
||||
FPVAL(11111111, 1000000000, 0000000000000), // NaN
|
||||
FPVAL(11111111, 1000000000, 0000000000000), // NaN
|
||||
FPVAL(11111111, 1000000000, 0000000000000) // NaN
|
||||
F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN
|
||||
F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN
|
||||
F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN
|
||||
F32VAL(0b11111111, 0b1000000000, 0b0000000000000) // NaN
|
||||
},
|
||||
// NAN with a 1 in the truncated bits.
|
||||
{
|
||||
FPVAL(11111111, 0000000000, 0000000000001), // NaN
|
||||
FPVAL(11111111, 0000000000, 0000000000001), // NaN
|
||||
FPVAL(11111111, 0000000000, 0000000000001), // NaN
|
||||
FPVAL(11111111, 0000000000, 0000000000001) // NaN
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN
|
||||
F32VAL(0b11111111, 0b0000000000, 0b0000000000001) // NaN
|
||||
},
|
||||
// NAN with all ones, causing rounding overflow.
|
||||
{
|
||||
FPVAL(11111111, 1111111111, 1111111111111), // NaN
|
||||
FPVAL(11111111, 1111111111, 1111111111111), // NaN
|
||||
FPVAL(11111111, 1111111111, 1111111111111), // NaN
|
||||
FPVAL(11111111, 1111111111, 1111111111111) // NaN
|
||||
F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN
|
||||
F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN
|
||||
F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN
|
||||
F32VAL(0b11111111, 0b1111111111, 0b1111111111111) // NaN
|
||||
}};
|
||||
|
||||
XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
|
||||
int index = GetParam();
|
||||
int exponent_bits = exponent_sizes[index];
|
||||
int mantissa_bits = mantissa_sizes[index];
|
||||
// F64VAL is like F32VAL but for doubles.
|
||||
//
|
||||
// Here the "high" mantissa bits are those retained with reduction to IEEE-f32
|
||||
// (the first 23 bits), and the "low" bits are those truncated with reduction to
|
||||
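// IEEE-f32 (the remaining 29 bits).
//
// The three fields are combined by addition without masking, so each argument
// must fit its field (11, 23 and 29 bits respectively); a wider value carries
// into the neighbouring field.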
|
||||
#define F64VAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
|
||||
((uint64{EXPONENT} << 52) + (uint64{HIGH_MANTISSA} << 29) + \
|
||||
uint64{LOW_MANTISSA})
|
||||
|
||||
std::vector<float> input_values;
|
||||
std::vector<float> expected_values;
|
||||
// We want to test IEEE-f64 (a no-op), IEEE-f32, and exponent-reduction-only and
|
||||
// mantissa-reduction-only variants of IEEE-f32.
|
||||
static const int f64_exponent_sizes[] = {11, 8, 8, 11};
|
||||
static const int f64_mantissa_sizes[] = {52, 23, 52, 23};
|
||||
|
||||
const uint32_t sign_bit = 1u << 31;
|
||||
// Bit patterns for the f64 accuracy cases.  Column 0 of each row is the input
// value; column i is the expected result for the i-th (exponent, mantissa)
// pair, so column 0 doubles as the expected value of the no-op reduction.
static const uint64 f64_test_values[][4] = {
    // True zero.
    {
        F64VAL(0b00000000000, 0x000000, 0x00000000),  // 0.0
        F64VAL(0b00000000000, 0x000000, 0x00000000),  // 0.0
        F64VAL(0b00000000000, 0x000000, 0x00000000),  // 0.0
        F64VAL(0b00000000000, 0x000000, 0x00000000),  // 0.0
    },
    // Largest exponent that underflows to zero, namely -127 (encoded as
    // -127 + 1023).
    {
        F64VAL(0b01110000000, 0x000000, 0x00000000),  // 5.8774717541114375e-39
        F64VAL(0b00000000000, 0x000000, 0x00000000),  // 0.0
        F64VAL(0b00000000000, 0x000000, 0x00000000),  // 0.0
        F64VAL(0b01110000000, 0x000000, 0x00000000),  // 5.8774717541114375e-39
    },
    // Largest value that rounds to a denormal and thus clamps to zero.
    {
        F64VAL(0b01110000000, 0x7FFFFF, 0x0FFFFFFF),  // 1.1754943157898258e-38
        F64VAL(0b00000000000, 0x000000, 0x00000000),  // 0.0
        F64VAL(0b00000000000, 0x000000, 0x00000000),  // 0.0
        F64VAL(0b01110000000, 0x7FFFFF, 0x00000000),  // 1.1754942807573643e-38
    },
    // Smallest value that doesn't underflow to zero, due to mantissa rounding
    // up and incrementing the exponent out of the denormal range.
    {
        F64VAL(0b01110000000, 0x7FFFFF, 0x10000000),  // 1.1754943157898259e-38
        F64VAL(0b01110000001, 0x000000, 0x00000000),  // 1.1754943508222875e-38
        F64VAL(0b00000000000, 0x000000, 0x00000000),  // 0.0
        F64VAL(0b01110000001, 0x000000, 0x00000000)   // 1.1754943508222875e-38
    },
    // Smallest value that doesn't underflow to zero even without mantissa
    // rounding.
    {
        F64VAL(0b01110000001, 0x000000, 0x00000000),  // 1.1754943508222875e-38
        F64VAL(0b01110000001, 0x000000, 0x00000000),  // 1.1754943508222875e-38
        F64VAL(0b01110000001, 0x000000, 0x00000000),  // 1.1754943508222875e-38
        F64VAL(0b01110000001, 0x000000, 0x00000000)   // 1.1754943508222875e-38
    },
    // One (to make sure bias-handling is done correctly).
    {
        F64VAL(0b01111111111, 0x000000, 0x00000000),  // 1.0
        F64VAL(0b01111111111, 0x000000, 0x00000000),  // 1.0
        F64VAL(0b01111111111, 0x000000, 0x00000000),  // 1.0
        F64VAL(0b01111111111, 0x000000, 0x00000000)   // 1.0
    },
    // Values in a space where ties round down due to ties-to-even:
    //   Value with highest mantissa that rounds down.
    {
        F64VAL(0b01111111111, 0x000000, 0x10000000),  // 1.0000000596046448
        F64VAL(0b01111111111, 0x000000, 0x00000000),  // 1.0
        F64VAL(0b01111111111, 0x000000, 0x10000000),  // 1.0000000596046448
        F64VAL(0b01111111111, 0x000000, 0x00000000)   // 1.0
    },
    //   Value with lowest mantissa that rounds up.
    {
        F64VAL(0b01111111111, 0x000000, 0x10000001),  // 1.000000059604645
        F64VAL(0b01111111111, 0x000001, 0x00000000),  // 1.0000001192092896
        F64VAL(0b01111111111, 0x000000, 0x10000001),  // 1.000000059604645
        F64VAL(0b01111111111, 0x000001, 0x00000000)   // 1.0000001192092896
    },
    // Values in a space where ties round up due to ties-to-even:
    //   Value with highest mantissa that rounds down.
    {
        F64VAL(0b01111111111, 0x000001, 0x0fffffff),  // 1.0000001788139341
        F64VAL(0b01111111111, 0x000001, 0x00000000),  // 1.0000001192092896
        F64VAL(0b01111111111, 0x000001, 0x0fffffff),  // 1.0000001788139341
        F64VAL(0b01111111111, 0x000001, 0x00000000)   // 1.0000001192092896
    },
    //   Value with a mantissa that rounds up.
    {
        F64VAL(0b01111111111, 0x000001, 0x10000000),  // 1.0000001788139343
        F64VAL(0b01111111111, 0x000002, 0x00000000),  // 1.0000002384185791
        F64VAL(0b01111111111, 0x000001, 0x10000000),  // 1.0000001788139343
        F64VAL(0b01111111111, 0x000002, 0x00000000),  // 1.0000002384185791
    },
    // Largest value that does not overflow to infinity.
    {
        F64VAL(0b10001111110, 0x7fffff, 0x0fffffff),  // 3.4028235677973362e+38
        F64VAL(0b10001111110, 0x7fffff, 0x00000000),  // 3.4028234663852886e+38
        F64VAL(0b10001111110, 0x7fffff, 0x0fffffff),  // 3.4028235677973362e+38
        F64VAL(0b10001111110, 0x7fffff, 0x00000000),  // 3.4028234663852886e+38
    },
    // Smallest value that overflows to infinity due to mantissa rounding up.
    {
        F64VAL(0b10001111110, 0x7fffff, 0x10000000),  // 3.4028235677973366e+38
        F64VAL(0b11111111111, 0x000000, 0x00000000),  // Inf
        F64VAL(0b10001111110, 0x7fffff, 0x10000000),  // 3.4028235677973366e+38
        F64VAL(0b10001111111, 0x000000, 0x00000000)   // 3.4028236692093846e+38
    },
    // Smallest value that overflows to infinity, without mantissa rounding.
    {
        F64VAL(0b10001111111, 0x000000, 0x00000000),  // 3.4028236692093846e+38
        F64VAL(0b11111111111, 0x000000, 0x00000000),  // Inf
        F64VAL(0b11111111111, 0x000000, 0x00000000),  // Inf
        F64VAL(0b10001111111, 0x000000, 0x00000000)   // 3.4028236692093846e+38
    },
    // Smallest value that overflows to infinity due to mantissa rounding up,
    // even when exponent bits aren't reduced.
    {
        F64VAL(0b11111111110, 0x7fffff, 0x10000000),  // 1.7976930812868855e+308
        F64VAL(0b11111111111, 0x000000, 0x00000000),  // Inf
        F64VAL(0b11111111111, 0x000000, 0x00000000),  // Inf
        F64VAL(0b11111111111, 0x000000, 0x00000000)   // Inf
    },
    // True infinity.
    {
        F64VAL(0b11111111111, 0x000000, 0x00000000),  // Inf
        F64VAL(0b11111111111, 0x000000, 0x00000000),  // Inf
        F64VAL(0b11111111111, 0x000000, 0x00000000),  // Inf
        F64VAL(0b11111111111, 0x000000, 0x00000000),  // Inf
    },
    // NAN with a 1 in the preserved bits.
    //
    // 0x400000 is the most significant of the 23 high-mantissa bits.  The
    // next bit up (0x800000) does not fit the field: shifted by 29 it lands on
    // bit 52, carries through the all-ones exponent and yields -0.0 instead of
    // a NaN.
    {
        F64VAL(0b11111111111, 0x400000, 0x00000000),  // NaN
        F64VAL(0b11111111111, 0x400000, 0x00000000),  // NaN
        F64VAL(0b11111111111, 0x400000, 0x00000000),  // NaN
        F64VAL(0b11111111111, 0x400000, 0x00000000),  // NaN
    },
    // NAN with a 1 in the truncated bits.
    {
        F64VAL(0b11111111111, 0x000000, 0x00000001),  // NaN
        F64VAL(0b11111111111, 0x000000, 0x00000001),  // NaN
        F64VAL(0b11111111111, 0x000000, 0x00000001),  // NaN
        F64VAL(0b11111111111, 0x000000, 0x00000001),  // NaN
    },
    // NAN with all ones, causing rounding overflow.
    {
        F64VAL(0b11111111111, 0x7fffff, 0x1fffffff),  // NaN
        F64VAL(0b11111111111, 0x7fffff, 0x1fffffff),  // NaN
        F64VAL(0b11111111111, 0x7fffff, 0x1fffffff),  // NaN
        F64VAL(0b11111111111, 0x7fffff, 0x1fffffff),  // NaN
    },
};
|
||||
|
||||
class ReducedPrecisionAccuracyTest : public ClientLibraryTestBase,
|
||||
public ::testing::WithParamInterface<int> {
|
||||
protected:
|
||||
template <typename Fp, typename Uint, int kNumTestcases, int kNumInputs>
|
||||
void DoIt(int exponent_bits, int mantissa_bits,
|
||||
const Uint (&test_values)[kNumInputs][kNumTestcases],
|
||||
int operation_index);
|
||||
};
|
||||
|
||||
// Eigen::half variant: the test parameter selects the entry of the f16 size
// arrays to use and is passed on to DoIt as the operation index.
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionHalf) {
  const int op = GetParam();
  const int exponent_bits = f16_exponent_sizes[op];
  const int mantissa_bits = f16_mantissa_sizes[op];
  DoIt<Eigen::half, uint16>(exponent_bits, mantissa_bits, f16_test_values, op);
}
|
||||
|
||||
// bfloat16 variant: the test parameter selects the entry of the bf16 size
// arrays to use and is passed on to DoIt as the operation index.
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionBfloat16) {
  const int op = GetParam();
  const int exponent_bits = bf16_exponent_sizes[op];
  const int mantissa_bits = bf16_mantissa_sizes[op];
  DoIt<bfloat16, uint16>(exponent_bits, mantissa_bits, bf16_test_values, op);
}
|
||||
|
||||
// float variant: the test parameter selects the entry of the f32 size arrays
// to use and is passed on to DoIt as the operation index.
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionFloat) {
  const int op = GetParam();
  const int exponent_bits = f32_exponent_sizes[op];
  const int mantissa_bits = f32_mantissa_sizes[op];
  DoIt<float, uint32>(exponent_bits, mantissa_bits, f32_test_values, op);
}
|
||||
|
||||
// double variant: the test parameter selects the entry of the f64 size arrays
// to use and is passed on to DoIt as the operation index.
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionDouble) {
  const int op = GetParam();
  const int exponent_bits = f64_exponent_sizes[op];
  const int mantissa_bits = f64_mantissa_sizes[op];
  DoIt<double, uint64>(exponent_bits, mantissa_bits, f64_test_values, op);
}
|
||||
|
||||
template <typename Fp, typename Uint, int kNumTestcases, int kNumInputs>
|
||||
void ReducedPrecisionAccuracyTest::DoIt(
|
||||
int exponent_bits, int mantissa_bits,
|
||||
const Uint (&test_values)[kNumInputs][kNumTestcases], int operation_index) {
|
||||
SCOPED_TRACE(absl::StrFormat("operation_index %d", operation_index));
|
||||
SCOPED_TRACE(absl::StrFormat("%d exponent bits, %d mantissa bits",
|
||||
exponent_bits, mantissa_bits));
|
||||
|
||||
std::vector<Fp> input_values;
|
||||
std::vector<Fp> expected_values;
|
||||
|
||||
const Uint sign_bit = Uint{1} << (sizeof(Fp) * 8 - 1);
|
||||
for (const auto& test_value : test_values) {
|
||||
// Add positive values.
|
||||
input_values.push_back(absl::bit_cast<float>(test_value[0]));
|
||||
expected_values.push_back(absl::bit_cast<float>(test_value[index]));
|
||||
input_values.push_back(absl::bit_cast<Fp>(test_value[0]));
|
||||
expected_values.push_back(absl::bit_cast<Fp>(test_value[operation_index]));
|
||||
// Add negative values. We do this in the bitwise representation so as to
|
||||
// avoid problems with NaN handling.
|
||||
input_values.push_back(absl::bit_cast<float>(test_value[0] ^ sign_bit));
|
||||
input_values.push_back(absl::bit_cast<Fp, Uint>(test_value[0] ^ sign_bit));
|
||||
expected_values.push_back(
|
||||
absl::bit_cast<float>(test_value[index] ^ sign_bit));
|
||||
absl::bit_cast<Fp, Uint>(test_value[operation_index] ^ sign_bit));
|
||||
}
|
||||
|
||||
// This is required for proper handling of NaN values.
|
||||
@ -229,19 +519,18 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
|
||||
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
Literal a_literal = LiteralUtil::CreateR1<float>({input_values});
|
||||
Literal a_literal = LiteralUtil::CreateR1<Fp>({input_values});
|
||||
std::unique_ptr<GlobalData> a_data =
|
||||
client_->TransferToServer(a_literal).ConsumeValueOrDie();
|
||||
auto a = Parameter(&builder, 0, a_literal.shape(), "a");
|
||||
|
||||
ReducePrecision(a, exponent_bits, mantissa_bits);
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, expected_values, {a_data.get()});
|
||||
ComputeAndCompareR1<Fp>(&builder, expected_values, {a_data.get()});
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ReducePrecisionAccuracyTest,
|
||||
ReducePrecisionAccuracyTest,
|
||||
::testing::Values(0, 1, 2, 3), TestDataToString);
|
||||
INSTANTIATE_TEST_CASE_P(ReducedPrecisionAccuracyTest,
|
||||
ReducedPrecisionAccuracyTest, ::testing::Range(0, 4));
|
||||
|
||||
// Tests to confirm that the compiler optimization functions add the expected
|
||||
// ReducePrecisionInsertion passes.
|
||||
|
Loading…
Reference in New Issue
Block a user