[XLA] Support the reduce-precision HLO for all floating-point types.
PiperOrigin-RevId: 258823530
This commit is contained in:
parent
c68ed58053
commit
4bd8619abf
@ -63,28 +63,44 @@ int64 GlobalRandomValue() {
|
|||||||
return rng();
|
return rng();
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
StatusOr<llvm::Value*> EmitReducePrecisionIR(PrimitiveType src_ty,
|
||||||
int64 mantissa_bits,
|
llvm::Value* x,
|
||||||
|
int64 dest_exponent_bits,
|
||||||
|
int64 dest_mantissa_bits,
|
||||||
llvm::IRBuilder<>* b) {
|
llvm::IRBuilder<>* b) {
|
||||||
|
using llvm::APInt;
|
||||||
|
|
||||||
|
if (!primitive_util::IsFloatingPointType(src_ty)) {
|
||||||
|
return Unimplemented(
|
||||||
|
"ReducePrecision cannot accept non-floating-point type %s.",
|
||||||
|
PrimitiveType_Name(src_ty));
|
||||||
|
}
|
||||||
|
|
||||||
// Integer and float types for casting and constant generation.
|
// Integer and float types for casting and constant generation.
|
||||||
llvm::Type* float_type = x->getType();
|
llvm::Type* float_type = x->getType();
|
||||||
llvm::IntegerType* int_type = b->getInt32Ty();
|
int64 nbits = float_type->getPrimitiveSizeInBits();
|
||||||
|
llvm::IntegerType* int_type = b->getIntNTy(nbits);
|
||||||
|
|
||||||
|
// SignificandWidth includes the implicit extra bit.
|
||||||
|
int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
|
||||||
|
int src_exponent_bits = nbits - 1 - src_mantissa_bits;
|
||||||
|
|
||||||
// Cast the input value to an integer for bitwise manipulation.
|
// Cast the input value to an integer for bitwise manipulation.
|
||||||
llvm::Value* x_as_int = b->CreateBitCast(x, int_type);
|
llvm::Value* x_as_int = b->CreateBitCast(x, int_type);
|
||||||
|
|
||||||
if (mantissa_bits < 23) {
|
if (dest_mantissa_bits < src_mantissa_bits) {
|
||||||
// Last remaining mantissa bit.
|
// Last remaining mantissa bit.
|
||||||
const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
|
APInt last_mantissa_bit_mask(nbits, 1);
|
||||||
|
last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;
|
||||||
|
|
||||||
// Compute rounding bias for round-to-nearest with ties to even. This is
|
// Compute rounding bias for round-to-nearest with ties to even. This is
|
||||||
// equal to a base value of 0111... plus one bit if the last remaining
|
// equal to a base value of 0111... plus one bit if the last remaining
|
||||||
// mantissa bit is 1.
|
// mantissa bit is 1.
|
||||||
const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
|
APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
|
||||||
llvm::Value* x_last_mantissa_bit = b->CreateLShr(
|
llvm::Value* x_last_mantissa_bit = b->CreateLShr(
|
||||||
b->CreateAnd(x_as_int,
|
b->CreateAnd(x_as_int,
|
||||||
llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
|
llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
|
||||||
(23 - mantissa_bits));
|
(src_mantissa_bits - dest_mantissa_bits));
|
||||||
llvm::Value* x_rounding_bias =
|
llvm::Value* x_rounding_bias =
|
||||||
b->CreateAdd(x_last_mantissa_bit,
|
b->CreateAdd(x_last_mantissa_bit,
|
||||||
llvm::ConstantInt::get(int_type, base_rounding_bias));
|
llvm::ConstantInt::get(int_type, base_rounding_bias));
|
||||||
@ -93,16 +109,19 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
|||||||
// where adding the rounding bias overflows into the exponent bits is
|
// where adding the rounding bias overflows into the exponent bits is
|
||||||
// correct; the non-masked mantissa bits will all be zero, and the
|
// correct; the non-masked mantissa bits will all be zero, and the
|
||||||
// exponent will be incremented by one.
|
// exponent will be incremented by one.
|
||||||
const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
|
APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
|
||||||
x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
|
x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
|
||||||
x_as_int = b->CreateAnd(x_as_int,
|
x_as_int = b->CreateAnd(x_as_int,
|
||||||
llvm::ConstantInt::get(int_type, truncation_mask));
|
llvm::ConstantInt::get(int_type, truncation_mask));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (exponent_bits < 8) {
|
if (dest_exponent_bits < src_exponent_bits) {
|
||||||
// Masks for f32 values.
|
APInt sign_bit_mask(nbits, 1);
|
||||||
const uint32_t f32_sign_bit_mask = 1u << 31;
|
sign_bit_mask <<= nbits - 1;
|
||||||
const uint32_t f32_exp_bits_mask = 0xffu << 23;
|
|
||||||
|
APInt exp_bits_mask(nbits, 1);
|
||||||
|
exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
|
||||||
|
<< src_mantissa_bits;
|
||||||
|
|
||||||
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
|
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
|
||||||
// significant bit -- is equal to 1.0f for all exponent sizes. Adding
|
// significant bit -- is equal to 1.0f for all exponent sizes. Adding
|
||||||
@ -116,28 +135,31 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
|||||||
// (2^7-1) - 2^(n-1)-1.
|
// (2^7-1) - 2^(n-1)-1.
|
||||||
//
|
//
|
||||||
// Note that we have already checked that exponents_bits >= 1.
|
// Note that we have already checked that exponents_bits >= 1.
|
||||||
const uint32_t f32_exponent_bias = (1 << 7) - 1;
|
APInt exponent_bias(nbits, 1);
|
||||||
const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
|
exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;
|
||||||
const uint32_t reduced_max_exponent =
|
|
||||||
f32_exponent_bias + reduced_exponent_bias;
|
APInt reduced_exponent_bias(nbits, 1);
|
||||||
const uint32_t reduced_min_exponent =
|
reduced_exponent_bias =
|
||||||
f32_exponent_bias - reduced_exponent_bias;
|
(reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;
|
||||||
|
|
||||||
|
APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
|
||||||
|
APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;
|
||||||
|
|
||||||
// Do we overflow or underflow?
|
// Do we overflow or underflow?
|
||||||
llvm::Value* x_exponent = b->CreateAnd(
|
llvm::Value* x_exponent =
|
||||||
x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
|
b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
|
||||||
llvm::Value* x_overflows = b->CreateICmpUGT(
|
llvm::Value* x_overflows = b->CreateICmpUGT(
|
||||||
x_exponent,
|
x_exponent, llvm::ConstantInt::get(
|
||||||
llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
|
int_type, reduced_max_exponent << src_mantissa_bits));
|
||||||
llvm::Value* x_underflows = b->CreateICmpULE(
|
llvm::Value* x_underflows = b->CreateICmpULE(
|
||||||
x_exponent,
|
x_exponent, llvm::ConstantInt::get(
|
||||||
llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
|
int_type, reduced_min_exponent << src_mantissa_bits));
|
||||||
|
|
||||||
// Compute appropriately-signed values of zero and infinity.
|
// Compute appropriately-signed values of zero and infinity.
|
||||||
llvm::Value* x_signed_zero = b->CreateAnd(
|
llvm::Value* x_signed_zero =
|
||||||
x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
|
b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
|
||||||
llvm::Value* x_signed_inf = b->CreateOr(
|
llvm::Value* x_signed_inf = b->CreateOr(
|
||||||
x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
|
x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));
|
||||||
|
|
||||||
// Force to zero or infinity if overflow or underflow. (Note that this
|
// Force to zero or infinity if overflow or underflow. (Note that this
|
||||||
// truncates all denormal values to zero, rather than rounding them.)
|
// truncates all denormal values to zero, rather than rounding them.)
|
||||||
@ -161,7 +183,7 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
|||||||
if (!b->getFastMathFlags().noNaNs()) {
|
if (!b->getFastMathFlags().noNaNs()) {
|
||||||
llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);
|
llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);
|
||||||
|
|
||||||
if (mantissa_bits > 0) {
|
if (dest_mantissa_bits > 0) {
|
||||||
result = b->CreateSelect(x_is_nan, x, result);
|
result = b->CreateSelect(x_is_nan, x, result);
|
||||||
} else {
|
} else {
|
||||||
result = b->CreateSelect(
|
result = b->CreateSelect(
|
||||||
@ -171,11 +193,14 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, llvm::IRBuilder<>* b) {
|
StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
|
||||||
auto reduced_precision = EmitReducePrecisionFloat(
|
llvm::IRBuilder<>* b) {
|
||||||
f32_value,
|
TF_ASSIGN_OR_RETURN(
|
||||||
/*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
|
auto reduced_precision,
|
||||||
/*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b);
|
EmitReducePrecisionIR(
|
||||||
|
/*src_ty=*/F32, f32_value,
|
||||||
|
/*dest_exponent_bits=*/primitive_util::kBFloat16ExponentBits,
|
||||||
|
/*dest_mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b));
|
||||||
auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
|
auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
|
||||||
auto shifted = b->CreateLShr(as_int32, 16);
|
auto shifted = b->CreateLShr(as_int32, 16);
|
||||||
auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
|
auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
|
||||||
@ -1099,11 +1124,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
|||||||
|
|
||||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
|
||||||
const HloInstruction* hlo, llvm::Value* x) {
|
const HloInstruction* hlo, llvm::Value* x) {
|
||||||
if (hlo->operand(0)->shape().element_type() != F32) {
|
return EmitReducePrecisionIR(
|
||||||
return Unimplemented("reduce-precision only implemented for F32");
|
/*src_ty=*/hlo->operand(0)->shape().element_type(), x,
|
||||||
}
|
/*dest_exponent_bits=*/hlo->exponent_bits(),
|
||||||
return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
|
/*dest_mantissa_bits=*/hlo->mantissa_bits(), b_);
|
||||||
/*mantissa_bits=*/hlo->mantissa_bits(), b_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
|
static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
|
||||||
@ -1990,7 +2014,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
llvm::Value* float_val =
|
llvm::Value* float_val =
|
||||||
b_->CreateUIToFP(elem_index_linear, float_ir_type);
|
b_->CreateUIToFP(elem_index_linear, float_ir_type);
|
||||||
if (component_element_type == BF16) {
|
if (component_element_type == BF16) {
|
||||||
iota_result = EmitF32ToBF16(float_val, b_);
|
TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
|
||||||
} else {
|
} else {
|
||||||
iota_result = float_val;
|
iota_result = float_val;
|
||||||
}
|
}
|
||||||
|
@ -67,6 +67,30 @@ T ToArithmeticSafeType(T t) {
|
|||||||
return std::move(t);
|
return std::move(t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UintWithSize<N> gets an unsigned integer with the given size in bytes.
|
||||||
|
template <size_t Bytes>
|
||||||
|
struct UintWithSize {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct UintWithSize<1> {
|
||||||
|
using type = uint8;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct UintWithSize<2> {
|
||||||
|
using type = uint16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct UintWithSize<4> {
|
||||||
|
using type = uint32;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct UintWithSize<8> {
|
||||||
|
using type = uint64;
|
||||||
|
};
|
||||||
|
|
||||||
// Templated DfsHloVisitor for use by HloEvaluator.
|
// Templated DfsHloVisitor for use by HloEvaluator.
|
||||||
//
|
//
|
||||||
// Typically ReturnT here indicates the resulting literal type of each evaluated
|
// Typically ReturnT here indicates the resulting literal type of each evaluated
|
||||||
@ -2389,48 +2413,57 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return HandleCos<ElementwiseT>(cos);
|
return HandleCos<ElementwiseT>(cos);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT, typename std::enable_if<std::is_same<
|
template <typename NativeT,
|
||||||
float, NativeT>::value>::type* = nullptr>
|
typename std::enable_if<
|
||||||
|
std::is_same<NativeT, float>::value ||
|
||||||
|
std::is_same<NativeT, double>::value>::type* = nullptr>
|
||||||
Status HandleReducePrecision(HloInstruction* reduce_precision) {
|
Status HandleReducePrecision(HloInstruction* reduce_precision) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
parent_->evaluated_[reduce_precision],
|
parent_->evaluated_[reduce_precision],
|
||||||
ElementWiseUnaryOp(reduce_precision, [reduce_precision](
|
ElementWiseUnaryOp(reduce_precision, [&](ElementwiseT elem) {
|
||||||
ElementwiseT elem) {
|
const uint32 src_mantissa_bits =
|
||||||
uint32_t value_as_int = absl::bit_cast<uint32_t>(elem);
|
std::numeric_limits<NativeT>::digits - 1;
|
||||||
const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
|
const uint32 src_exponent_bits =
|
||||||
const uint32_t exponent_bits = reduce_precision->exponent_bits();
|
8 * sizeof(NativeT) - src_mantissa_bits - 1;
|
||||||
|
const uint32 dest_mantissa_bits = reduce_precision->mantissa_bits();
|
||||||
|
const uint32 dest_exponent_bits = reduce_precision->exponent_bits();
|
||||||
|
|
||||||
|
using Uint = typename UintWithSize<sizeof(NativeT)>::type;
|
||||||
|
Uint value_as_int = absl::bit_cast<Uint>(elem);
|
||||||
|
|
||||||
// Code is based on the CPU/GPU implementation in LLVM-emitting code.
|
// Code is based on the CPU/GPU implementation in LLVM-emitting code.
|
||||||
//
|
//
|
||||||
// Bits in float type:
|
// Bits in float32 type:
|
||||||
// mantissa : bits [0:22]
|
// mantissa : bits [0:22]
|
||||||
// exponent : bits [23:30]
|
// exponent : bits [23:30]
|
||||||
// sign : bits [31]
|
// sign : bits [31]
|
||||||
if (mantissa_bits < 23) {
|
if (dest_mantissa_bits < src_mantissa_bits) {
|
||||||
const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
|
const Uint last_mantissa_bit_mask =
|
||||||
|
Uint{1} << (src_mantissa_bits - dest_mantissa_bits);
|
||||||
|
|
||||||
// Compute rounding bias for round-to-nearest with ties to even.
|
// Compute rounding bias for round-to-nearest with ties to even.
|
||||||
// This is equal to a base value of 0111... plus one bit if the last
|
// This is equal to a base value of 0111... plus one bit if the last
|
||||||
// remaining mantissa bit is 1.
|
// remaining mantissa bit is 1.
|
||||||
const uint32_t base_rounding_bias =
|
const Uint base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
|
||||||
(last_mantissa_bit_mask >> 1) - 1;
|
const Uint x_last_mantissa_bit =
|
||||||
const uint32_t x_last_mantissa_bit =
|
(value_as_int & last_mantissa_bit_mask) >>
|
||||||
(value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
|
(src_mantissa_bits - dest_mantissa_bits);
|
||||||
const uint32_t x_rounding_bias =
|
const Uint x_rounding_bias =
|
||||||
x_last_mantissa_bit + base_rounding_bias;
|
x_last_mantissa_bit + base_rounding_bias;
|
||||||
|
|
||||||
// Add rounding bias, and mask out truncated bits. Note that the
|
// Add rounding bias, and mask out truncated bits. Note that the
|
||||||
// case where adding the rounding bias overflows into the exponent
|
// case where adding the rounding bias overflows into the exponent
|
||||||
// bits is correct; the non-masked mantissa bits will all be zero,
|
// bits is correct; the non-masked mantissa bits will all be zero,
|
||||||
// and the exponent will be incremented by one.
|
// and the exponent will be incremented by one.
|
||||||
const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
|
const Uint truncation_mask = ~(last_mantissa_bit_mask - 1);
|
||||||
value_as_int = value_as_int + x_rounding_bias;
|
value_as_int = value_as_int + x_rounding_bias;
|
||||||
value_as_int = value_as_int & truncation_mask;
|
value_as_int = value_as_int & truncation_mask;
|
||||||
}
|
}
|
||||||
if (exponent_bits < 8) {
|
if (dest_exponent_bits < src_exponent_bits) {
|
||||||
// Masks for f32 values.
|
// Masks for f32 values.
|
||||||
const uint32_t f32_sign_bit_mask = 1u << 31;
|
const Uint sign_bit_mask = Uint{1} << 8 * sizeof(NativeT) - 1;
|
||||||
const uint32_t f32_exp_bits_mask = 0xffu << 23;
|
const Uint exp_bits_mask = (Uint{1 << src_exponent_bits} - 1)
|
||||||
|
<< src_mantissa_bits;
|
||||||
|
|
||||||
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the
|
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the
|
||||||
// most- significant bit -- is equal to 1.0f for all exponent sizes.
|
// most- significant bit -- is equal to 1.0f for all exponent sizes.
|
||||||
@ -2444,23 +2477,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
// is (2^7-1) - 2^(n-1)-1.
|
// is (2^7-1) - 2^(n-1)-1.
|
||||||
//
|
//
|
||||||
// Note that we have already checked that exponents_bits >= 1.
|
// Note that we have already checked that exponents_bits >= 1.
|
||||||
const uint32_t f32_exponent_bias = (1 << 7) - 1;
|
const Uint exponent_bias = (Uint{1} << (src_exponent_bits - 1)) - 1;
|
||||||
const uint32_t reduced_exponent_bias =
|
const Uint reduced_exponent_bias =
|
||||||
(1 << (exponent_bits - 1)) - 1;
|
(1 << (dest_exponent_bits - 1)) - 1;
|
||||||
const uint32_t reduced_max_exponent =
|
const Uint reduced_max_exponent =
|
||||||
f32_exponent_bias + reduced_exponent_bias;
|
exponent_bias + reduced_exponent_bias;
|
||||||
const uint32_t reduced_min_exponent =
|
const Uint reduced_min_exponent =
|
||||||
f32_exponent_bias - reduced_exponent_bias;
|
exponent_bias - reduced_exponent_bias;
|
||||||
|
|
||||||
// Do we overflow or underflow?
|
// Do we overflow or underflow?
|
||||||
const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
|
const Uint x_exponent = value_as_int & exp_bits_mask;
|
||||||
const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
|
const bool x_overflows =
|
||||||
|
x_exponent > (reduced_max_exponent << src_mantissa_bits);
|
||||||
const bool x_underflows =
|
const bool x_underflows =
|
||||||
x_exponent <= (reduced_min_exponent << 23);
|
x_exponent <= (reduced_min_exponent << src_mantissa_bits);
|
||||||
|
|
||||||
// Compute appropriately-signed values of zero and infinity.
|
// Compute appropriately-signed values of zero and infinity.
|
||||||
const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
|
const Uint x_signed_zero = value_as_int & sign_bit_mask;
|
||||||
const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;
|
const Uint x_signed_inf = x_signed_zero | exp_bits_mask;
|
||||||
|
|
||||||
// Force to zero or infinity if overflow or underflow. (Note that
|
// Force to zero or infinity if overflow or underflow. (Note that
|
||||||
// this truncates all denormal values to zero, rather than rounding
|
// this truncates all denormal values to zero, rather than rounding
|
||||||
@ -2469,23 +2503,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
value_as_int = x_underflows ? x_signed_zero : value_as_int;
|
value_as_int = x_underflows ? x_signed_zero : value_as_int;
|
||||||
}
|
}
|
||||||
|
|
||||||
float reduced_result = absl::bit_cast<float>(value_as_int);
|
NativeT reduced_result = absl::bit_cast<NativeT>(value_as_int);
|
||||||
if (std::isnan(elem)) {
|
if (std::isnan(elem)) {
|
||||||
reduced_result = mantissa_bits > 0
|
reduced_result = dest_mantissa_bits > 0
|
||||||
? elem
|
? elem
|
||||||
: std::numeric_limits<float>::infinity();
|
: std::numeric_limits<NativeT>::infinity();
|
||||||
}
|
}
|
||||||
return reduced_result;
|
return reduced_result;
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT, typename std::enable_if<std::is_same<
|
|
||||||
double, NativeT>::value>::type* = nullptr>
|
|
||||||
Status HandleReducePrecision(HloInstruction* reduce_precision) {
|
|
||||||
return InvalidArgument("Double is not supported for reduce precision");
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
template <
|
||||||
typename NativeT,
|
typename NativeT,
|
||||||
typename std::enable_if<std::is_integral<NativeT>::value ||
|
typename std::enable_if<std::is_integral<NativeT>::value ||
|
||||||
|
@ -39,189 +39,479 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Tests to confirm that the ReducePrecision operation produces the expected
|
// Testcases in this file work as follows.
|
||||||
// numerical values.
|
|
||||||
class ReducePrecisionAccuracyTest : public ClientLibraryTestBase,
|
|
||||||
public ::testing::WithParamInterface<int> {
|
|
||||||
};
|
|
||||||
|
|
||||||
// For reduction to IEEE-f16, we want to test the following cases, in both
|
|
||||||
// positive and negative variants. (Note: IEEE-f16 is 5 exponent bits and 10
|
|
||||||
// mantissa bits.)
|
|
||||||
//
|
//
|
||||||
// Vectors of exponent and mantissa sizes to test. We want to test IEEE-f32 (a
|
// for ty in {f16, bf16, f32, f64}:
|
||||||
// no-op), IEEE-f16, and exponent-reduction-only and mantissa-reduction-only
|
// for (operation_index, (e, m)) in \
|
||||||
// variants of IEEE-f16.
|
// enumerate(zip(ty_exponent_sizes, ty_mantissa_sizes)):
|
||||||
static const int exponent_sizes[] = {8, 5, 5, 8};
|
//
|
||||||
static const int mantissa_sizes[] = {23, 10, 23, 10};
|
// for testcase in ty_test_values:
|
||||||
|
// let expected = testcase[0]
|
||||||
|
// let input = testcase[operation_index]
|
||||||
|
//
|
||||||
|
// CHECK that XLA-reduce-precision(
|
||||||
|
// input, /*exp_bits=*/e, /*mantissa_bits=*/m) == expected
|
||||||
|
//
|
||||||
|
// Put into words:
|
||||||
|
//
|
||||||
|
// - ty_{exponent,mantissa}_sizes tell us the different ways we will reduce the
|
||||||
|
// precision of `ty`.
|
||||||
|
//
|
||||||
|
// - ty_test_values is a 2D array of testcases, each of which is
|
||||||
|
// len(ty_exponent_sizes) elements long. The first element corresponds to
|
||||||
|
// the input, and the j'th element corresponds to the expected output of
|
||||||
|
// doing a reduce-precision with parameters ty_{exponent,mantissa}_sizes[j].
|
||||||
|
//
|
||||||
|
// You'll note above that testcase[0] is reused as both expected and input when
|
||||||
|
// operation_index == 0. This implies that ty_{exponent,mantissa}_sizes[0] must
|
||||||
|
// be equal to `ty`'s exponent/mantissa size, making the reduce-precision op
|
||||||
|
// tested a nop.
|
||||||
|
|
||||||
string TestDataToString(const ::testing::TestParamInfo<int> data) {
|
// We want to test IEEE-f16 (a nop), cases that reduce just the
|
||||||
int i = data.param;
|
// mantissa/exponent, and a case that reduces both.
|
||||||
return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i],
|
//
|
||||||
"_mantissa_bits");
|
// We don't have a lot of tests here, relying instead on the coverage we have of
|
||||||
}
|
// f32 and f64.
|
||||||
|
//
|
||||||
|
// Note: The hypothetical float(3,7) type we're "converting" to would have:
|
||||||
|
// max exp = 2^(3-1) - 1 = 3
|
||||||
|
// min exp = -max_exp + 1 = -2
|
||||||
|
static const int f16_exponent_sizes[] = {5, 3, 3, 5};
|
||||||
|
static const int f16_mantissa_sizes[] = {10, 7, 10, 7};
|
||||||
|
|
||||||
// The FPVAL macro allows us to write out the binary representation of the
|
// The F16VAL macro lets us write out the binary representation of an f16 in a
|
||||||
// input and expected values in a more readable manner. The mantissa bits
|
// more readable manner, separating out the exponent and mantissa.
|
||||||
// are separated into the "high" bits (retained with reduction to IEEE-f16)
|
#define F16VAL(EXPONENT, MANTISSA) ((EXPONENT << 10) + (MANTISSA))
|
||||||
// and the "low" bits (truncated with reduction to IEEE-f16).
|
|
||||||
#define FPVAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
|
|
||||||
((0b##EXPONENT << 23) + (0b##HIGH_MANTISSA << 13) + (0b##LOW_MANTISSA))
|
|
||||||
|
|
||||||
// Each element in the test-value array consists of four numbers. The first is
|
static const uint16 f16_test_values[][4] = {
|
||||||
// the input value and the following are the expected output values for the
|
|
||||||
// various precision-reduction cases.
|
|
||||||
static const uint32_t test_values[][4] = {
|
|
||||||
// True zero.
|
// True zero.
|
||||||
{
|
{
|
||||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
F16VAL(0b00000, 0b0000000000), // 0.0
|
||||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
F16VAL(0b00000, 0b0000000000), // 0.0
|
||||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
F16VAL(0b00000, 0b0000000000), // 0.0
|
||||||
FPVAL(00000000, 0000000000, 0000000000000) // 0.0
|
F16VAL(0b00000, 0b0000000000), // 0.0
|
||||||
},
|
},
|
||||||
// Largest exponent that underflows to zero.
|
// One.
|
||||||
{
|
{
|
||||||
FPVAL(01110000, 0000000000, 0000000000000), // 3.05176e-05
|
F16VAL(0b01111, 0b0000000000), // 1.0
|
||||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
F16VAL(0b01111, 0b0000000000), // 1.0
|
||||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
F16VAL(0b01111, 0b0000000000), // 1.0
|
||||||
FPVAL(01110000, 0000000000, 0000000000000) // 3.05176e-05
|
F16VAL(0b01111, 0b0000000000), // 1.0
|
||||||
},
|
},
|
||||||
// Largest value that rounds to a denormal and thus clamps to zero.
|
// Largest exponent that underflows to zero is -3, which is encoded as
|
||||||
|
// -3 + 15 = 12
|
||||||
{
|
{
|
||||||
FPVAL(01110000, 1111111111, 0111111111111), // 6.10203e-05
|
F16VAL(0b01100, 0b0000000000), // 2^-3
|
||||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
F16VAL(0b00000, 0b0000000000), // 0
|
||||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
F16VAL(0b00000, 0b0000000000), // 0
|
||||||
FPVAL(01110000, 1111111111, 0000000000000) // 6.10054e-05
|
F16VAL(0b01100, 0b0000000000), // 2^-3
|
||||||
},
|
},
|
||||||
// Smallest value that doesn't underflow to zero, due to mantissa rounding
|
// Smallest value that doesn't underflow to zero, due to mantissa rounding
|
||||||
// up and incrementing the exponent out of the denormal range.
|
// up and incrementing the exponent out of the denormal range.
|
||||||
{
|
{
|
||||||
FPVAL(01110000, 1111111111, 1000000000000), // 6.10203e-05
|
F16VAL(0b01100, 0b1111111100), // 1020 * 2^-3
|
||||||
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
|
F16VAL(0b01101, 0b0000000000), // 2^-2
|
||||||
FPVAL(00000000, 0000000000, 0000000000000), // 0.0
|
F16VAL(0b00000, 0b0000000000), // 0
|
||||||
FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05
|
F16VAL(0b01101, 0b0000000000), // 2^-2
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// We want to test bfloat16 (a nop), cases that reduce just the
|
||||||
|
// mantissa/exponent, and a case that reduces both.
|
||||||
|
//
|
||||||
|
// We don't have a lot of tests here, relying instead on the coverage we have of
|
||||||
|
// f32 and f64.
|
||||||
|
static const int bf16_exponent_sizes[] = {8, 5, 5, 8};
|
||||||
|
static const int bf16_mantissa_sizes[] = {7, 5, 7, 5};
|
||||||
|
|
||||||
|
// The BF16VAL macro lets us write out the binary representation of a bf16 in a
|
||||||
|
// more readable manner, separating out the exponent and mantissa.
|
||||||
|
#define BF16VAL(EXPONENT, MANTISSA) ((EXPONENT << 7) + (MANTISSA))
|
||||||
|
|
||||||
|
static const uint16 bf16_test_values[][4] = {
|
||||||
|
// True zero.
|
||||||
|
{
|
||||||
|
BF16VAL(0b00000, 0b0000000000), // 0.0
|
||||||
|
BF16VAL(0b00000, 0b0000000000), // 0.0
|
||||||
|
BF16VAL(0b00000, 0b0000000000), // 0.0
|
||||||
|
BF16VAL(0b00000, 0b0000000000), // 0.0
|
||||||
|
},
|
||||||
|
// One.
|
||||||
|
{
|
||||||
|
BF16VAL(0b01111111, 0b0000000), // 1.0
|
||||||
|
BF16VAL(0b01111111, 0b0000000), // 1.0
|
||||||
|
BF16VAL(0b01111111, 0b0000000), // 1.0
|
||||||
|
BF16VAL(0b01111111, 0b0000000), // 1.0
|
||||||
|
},
|
||||||
|
// Largest exponent that underflows to zero.
|
||||||
|
{
|
||||||
|
BF16VAL(0b01110000, 0b0000000), // 3.05176e-05
|
||||||
|
BF16VAL(0b00000000, 0b0000000), // 0.0
|
||||||
|
BF16VAL(0b00000000, 0b0000000), // 0.0
|
||||||
|
BF16VAL(0b01110000, 0b0000000) // 3.05176e-05
|
||||||
|
},
|
||||||
|
// Smallest value that doesn't underflow to zero, due to mantissa rounding
|
||||||
|
// up and incrementing the exponent out of the denormal range.
|
||||||
|
{
|
||||||
|
BF16VAL(0b01110000, 0b1111110), // 6.05583e-05
|
||||||
|
BF16VAL(0b01110001, 0b0000000), // 6.10352e-05
|
||||||
|
BF16VAL(0b00000000, 0b0000000), // 0.0
|
||||||
|
BF16VAL(0b01110001, 0b0000000), // 6.10352e-05
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// We want to test IEEE-f32 (a no-op), IEEE-f16, and exponent-reduction-only and
|
||||||
|
// mantissa-reduction-only variants of IEEE-f16.
|
||||||
|
static const int f32_exponent_sizes[] = {8, 5, 5, 8};
|
||||||
|
static const int f32_mantissa_sizes[] = {23, 10, 23, 10};
|
||||||
|
|
||||||
|
// The F32VAL macro allows us to write out the binary representation of the
|
||||||
|
// input and expected values in a more readable manner. The mantissa bits
|
||||||
|
// are separated into the "high" bits (retained with reduction to IEEE-f16)
|
||||||
|
// and the "low" bits (truncated with reduction to IEEE-f16).
|
||||||
|
#define F32VAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
|
||||||
|
((EXPONENT << 23) + (HIGH_MANTISSA << 13) + (LOW_MANTISSA))
|
||||||
|
|
||||||
|
static const uint32 f32_test_values[][4] = {
|
||||||
|
// True zero.
|
||||||
|
{
|
||||||
|
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||||
|
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||||
|
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||||
|
F32VAL(0b00000000, 0b0000000000, 0b0000000000000) // 0.0
|
||||||
|
},
|
||||||
|
// Largest exponent that underflows to zero.
|
||||||
|
{
|
||||||
|
F32VAL(0b01110000, 0b0000000000, 0b0000000000000), // 3.05176e-05
|
||||||
|
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||||
|
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||||
|
F32VAL(0b01110000, 0b0000000000, 0b0000000000000) // 3.05176e-05
|
||||||
|
},
|
||||||
|
// Largest value that rounds to a denormal and thus clamps to zero.
|
||||||
|
{
|
||||||
|
F32VAL(0b01110000, 0b1111111111, 0b0111111111111), // 6.10203e-05
|
||||||
|
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||||
|
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||||
|
F32VAL(0b01110000, 0b1111111111, 0b0000000000000) // 6.10054e-05
|
||||||
|
},
|
||||||
|
// Smallest value that doesn't underflow to zero, due to mantissa rounding
|
||||||
|
// up and incrementing the exponent out of the denormal range.
|
||||||
|
{
|
||||||
|
F32VAL(0b01110000, 0b1111111111, 0b1000000000000), // 6.10203e-05
|
||||||
|
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
|
||||||
|
F32VAL(0b00000000, 0b0000000000, 0b0000000000000), // 0.0
|
||||||
|
F32VAL(0b01110001, 0b0000000000, 0b0000000000000) // 6.10352e-05
|
||||||
},
|
},
|
||||||
// Smallest value that doesn't underflow to zero even without mantissa
|
// Smallest value that doesn't underflow to zero even without mantissa
|
||||||
// rounding.
|
// rounding.
|
||||||
{
|
{
|
||||||
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
|
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
|
||||||
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
|
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
|
||||||
FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05
|
F32VAL(0b01110001, 0b0000000000, 0b0000000000000), // 6.10352e-05
|
||||||
FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05
|
F32VAL(0b01110001, 0b0000000000, 0b0000000000000) // 6.10352e-05
|
||||||
},
|
},
|
||||||
// One (to make sure bias-handling is done correctly.
|
// One (to make sure bias-handling is done correctly).
|
||||||
{
|
{
|
||||||
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
|
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
|
||||||
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
|
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
|
||||||
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
|
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
|
||||||
FPVAL(01111111, 0000000000, 0000000000000) // 1.0
|
F32VAL(0b01111111, 0b0000000000, 0b0000000000000) // 1.0
|
||||||
},
|
},
|
||||||
// Values in a space where ties round down due to ties-to-even:
|
// Values in a space where ties round down due to ties-to-even:
|
||||||
// Value with highest mantissa that rounds down.
|
// Value with highest mantissa that rounds down.
|
||||||
{
|
{
|
||||||
FPVAL(01111111, 0000000000, 1000000000000), // 1.00049
|
F32VAL(0b01111111, 0b0000000000, 0b1000000000000), // 1.00049
|
||||||
FPVAL(01111111, 0000000000, 0000000000000), // 1.0
|
F32VAL(0b01111111, 0b0000000000, 0b0000000000000), // 1.0
|
||||||
FPVAL(01111111, 0000000000, 1000000000000), // 1.00049
|
F32VAL(0b01111111, 0b0000000000, 0b1000000000000), // 1.00049
|
||||||
FPVAL(01111111, 0000000000, 0000000000000) // 1.0
|
F32VAL(0b01111111, 0b0000000000, 0b0000000000000) // 1.0
|
||||||
},
|
},
|
||||||
// Value with lowest mantissa that rounds up.
|
// Value with lowest mantissa that rounds up.
|
||||||
{
|
{
|
||||||
FPVAL(01111111, 0000000000, 1000000000001), // 1.00049
|
F32VAL(0b01111111, 0b0000000000, 0b1000000000001), // 1.00049
|
||||||
FPVAL(01111111, 0000000001, 0000000000000), // 1.00098
|
F32VAL(0b01111111, 0b0000000001, 0b0000000000000), // 1.00098
|
||||||
FPVAL(01111111, 0000000000, 1000000000001), // 1.00049
|
F32VAL(0b01111111, 0b0000000000, 0b1000000000001), // 1.00049
|
||||||
FPVAL(01111111, 0000000001, 0000000000000) // 1.00098
|
F32VAL(0b01111111, 0b0000000001, 0b0000000000000) // 1.00098
|
||||||
},
|
},
|
||||||
// Values in a space where ties round up due to ties-to-even:
|
// Values in a space where ties round up due to ties-to-even:
|
||||||
// Value with highest mantissa that rounds down.
|
// Value with highest mantissa that rounds down.
|
||||||
{
|
{
|
||||||
FPVAL(01111111, 0000000001, 0111111111111), // 1.00146
|
F32VAL(0b01111111, 0b0000000001, 0b0111111111111), // 1.00146
|
||||||
FPVAL(01111111, 0000000001, 0000000000000), // 1.00098
|
F32VAL(0b01111111, 0b0000000001, 0b0000000000000), // 1.00098
|
||||||
FPVAL(01111111, 0000000001, 0111111111111), // 1.00146
|
F32VAL(0b01111111, 0b0000000001, 0b0111111111111), // 1.00146
|
||||||
FPVAL(01111111, 0000000001, 0000000000000) // 1.00098
|
F32VAL(0b01111111, 0b0000000001, 0b0000000000000) // 1.00098
|
||||||
},
|
},
|
||||||
// Value with a mantissa that rounds up.
|
// Value with a mantissa that rounds up.
|
||||||
{
|
{
|
||||||
FPVAL(01111111, 0000000001, 1000000000000), // 1.00146
|
F32VAL(0b01111111, 0b0000000001, 0b1000000000000), // 1.00146
|
||||||
FPVAL(01111111, 0000000010, 0000000000000), // 1.00195
|
F32VAL(0b01111111, 0b0000000010, 0b0000000000000), // 1.00195
|
||||||
FPVAL(01111111, 0000000001, 1000000000000), // 1.00146
|
F32VAL(0b01111111, 0b0000000001, 0b1000000000000), // 1.00146
|
||||||
FPVAL(01111111, 0000000010, 0000000000000) // 1.00195
|
F32VAL(0b01111111, 0b0000000010, 0b0000000000000) // 1.00195
|
||||||
},
|
},
|
||||||
// Largest value that does not overflow to infinity.
|
// Largest value that does not overflow to infinity.
|
||||||
{
|
{
|
||||||
FPVAL(10001110, 1111111111, 0111111111111), // 65520.0
|
F32VAL(0b10001110, 0b1111111111, 0b0111111111111), // 65520.0
|
||||||
FPVAL(10001110, 1111111111, 0000000000000), // 65504.0
|
F32VAL(0b10001110, 0b1111111111, 0b0000000000000), // 65504.0
|
||||||
FPVAL(10001110, 1111111111, 0111111111111), // 65520.0
|
F32VAL(0b10001110, 0b1111111111, 0b0111111111111), // 65520.0
|
||||||
FPVAL(10001110, 1111111111, 0000000000000) // 65504.0
|
F32VAL(0b10001110, 0b1111111111, 0b0000000000000) // 65504.0
|
||||||
},
|
},
|
||||||
// Smallest value that overflows to infinity due to mantissa rounding up.
|
// Smallest value that overflows to infinity due to mantissa rounding up.
|
||||||
{
|
{
|
||||||
FPVAL(10001110, 1111111111, 1000000000000), // 65520.0
|
F32VAL(0b10001110, 0b1111111111, 0b1000000000000), // 65520.0
|
||||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||||
FPVAL(10001110, 1111111111, 1000000000000), // 65520.0
|
F32VAL(0b10001110, 0b1111111111, 0b1000000000000), // 65520.0
|
||||||
FPVAL(10001111, 0000000000, 0000000000000) // 65536.0
|
F32VAL(0b10001111, 0b0000000000, 0b0000000000000) // 65536.0
|
||||||
},
|
},
|
||||||
// Smallest value that overflows to infinity, without mantissa rounding.
|
// Smallest value that overflows to infinity, without mantissa rounding.
|
||||||
{
|
{
|
||||||
FPVAL(10001111, 0000000000, 0000000000000), // 65536.0
|
F32VAL(0b10001111, 0b0000000000, 0b0000000000000), // 65536.0
|
||||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||||
FPVAL(10001111, 0000000000, 0000000000000) // 65536.0
|
F32VAL(0b10001111, 0b0000000000, 0b0000000000000) // 65536.0
|
||||||
},
|
},
|
||||||
// Smallest value that overflows to infinity due to mantissa rounding up,
|
// Smallest value that overflows to infinity due to mantissa rounding up,
|
||||||
// even when exponent bits aren't reduced.
|
// even when exponent bits aren't reduced.
|
||||||
{
|
{
|
||||||
FPVAL(11111110, 1111111111, 1000000000000), // 3.40199e+38
|
F32VAL(0b11111110, 0b1111111111, 0b1000000000000), // 3.40199e+38
|
||||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||||
FPVAL(11111111, 0000000000, 0000000000000) // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000) // Inf
|
||||||
},
|
},
|
||||||
// True infinity.
|
// True infinity.
|
||||||
{
|
{
|
||||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||||
FPVAL(11111111, 0000000000, 0000000000000), // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000), // Inf
|
||||||
FPVAL(11111111, 0000000000, 0000000000000) // Inf
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000000) // Inf
|
||||||
},
|
},
|
||||||
// NAN with a 1 in the preserved bits.
|
// NAN with a 1 in the preserved bits.
|
||||||
{
|
{
|
||||||
FPVAL(11111111, 1000000000, 0000000000000), // NaN
|
F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN
|
||||||
FPVAL(11111111, 1000000000, 0000000000000), // NaN
|
F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN
|
||||||
FPVAL(11111111, 1000000000, 0000000000000), // NaN
|
F32VAL(0b11111111, 0b1000000000, 0b0000000000000), // NaN
|
||||||
FPVAL(11111111, 1000000000, 0000000000000) // NaN
|
F32VAL(0b11111111, 0b1000000000, 0b0000000000000) // NaN
|
||||||
},
|
},
|
||||||
// NAN with a 1 in the truncated bits.
|
// NAN with a 1 in the truncated bits.
|
||||||
{
|
{
|
||||||
FPVAL(11111111, 0000000000, 0000000000001), // NaN
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN
|
||||||
FPVAL(11111111, 0000000000, 0000000000001), // NaN
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN
|
||||||
FPVAL(11111111, 0000000000, 0000000000001), // NaN
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000001), // NaN
|
||||||
FPVAL(11111111, 0000000000, 0000000000001) // NaN
|
F32VAL(0b11111111, 0b0000000000, 0b0000000000001) // NaN
|
||||||
},
|
},
|
||||||
// NAN with all ones, causing rounding overflow.
|
// NAN with all ones, causing rounding overflow.
|
||||||
{
|
{
|
||||||
FPVAL(11111111, 1111111111, 1111111111111), // NaN
|
F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN
|
||||||
FPVAL(11111111, 1111111111, 1111111111111), // NaN
|
F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN
|
||||||
FPVAL(11111111, 1111111111, 1111111111111), // NaN
|
F32VAL(0b11111111, 0b1111111111, 0b1111111111111), // NaN
|
||||||
FPVAL(11111111, 1111111111, 1111111111111) // NaN
|
F32VAL(0b11111111, 0b1111111111, 0b1111111111111) // NaN
|
||||||
}};
|
}};
|
||||||
|
|
||||||
XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
|
// F64VAL is like F32VAL but for doubles.
|
||||||
int index = GetParam();
|
//
|
||||||
int exponent_bits = exponent_sizes[index];
|
// Here the "high" mantissa bits are those retained with reduction to IEEE-f32
|
||||||
int mantissa_bits = mantissa_sizes[index];
|
// (the first 23 bits), and the "low" bits are those truncated with reduction to
|
||||||
|
// IEEE-f32 (the remaining 29 bits).
|
||||||
|
#define F64VAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \
|
||||||
|
((uint64{EXPONENT} << 52) + (uint64{HIGH_MANTISSA} << 29) + \
|
||||||
|
uint64{LOW_MANTISSA})
|
||||||
|
|
||||||
std::vector<float> input_values;
|
// We want to test IEEE-f64 (a no-op), IEEE-f32, and exponent-reduction-only and
|
||||||
std::vector<float> expected_values;
|
// mantissa-reduction-only variants of IEEE-f32.
|
||||||
|
static const int f64_exponent_sizes[] = {11, 8, 8, 11};
|
||||||
|
static const int f64_mantissa_sizes[] = {52, 23, 52, 23};
|
||||||
|
|
||||||
const uint32_t sign_bit = 1u << 31;
|
static const uint64 f64_test_values[][4] = {
|
||||||
|
// True zero.
|
||||||
|
{
|
||||||
|
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
|
||||||
|
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
|
||||||
|
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
|
||||||
|
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
|
||||||
|
},
|
||||||
|
// Largest exponent that underflows to zero, namely -127 (encoded as
|
||||||
|
// -127 + 1023).
|
||||||
|
{
|
||||||
|
F64VAL(0b01110000000, 0x000000, 0x00000000), // 5.8774717541114375e-39
|
||||||
|
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
|
||||||
|
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
|
||||||
|
F64VAL(0b01110000000, 0x000000, 0x00000000), // 5.8774717541114375e-39
|
||||||
|
},
|
||||||
|
// Largest value that rounds to a denormal and thus clamps to zero.
|
||||||
|
{
|
||||||
|
F64VAL(0b01110000000, 0x7FFFFF, 0x0FFFFFFF), // 1.1754943157898258e-38
|
||||||
|
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
|
||||||
|
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
|
||||||
|
F64VAL(0b01110000000, 0x7FFFFF, 0x00000000), // 1.1754942807573643e-38
|
||||||
|
},
|
||||||
|
// Smallest value that doesn't underflow to zero, due to mantissa rounding
|
||||||
|
// up and incrementing the exponent out of the denormal range.
|
||||||
|
{
|
||||||
|
F64VAL(0b01110000000, 0x7FFFFF, 0x10000000), // 1.1754943157898259e-38
|
||||||
|
F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38
|
||||||
|
F64VAL(0b00000000000, 0x000000, 0x00000000), // 0.0
|
||||||
|
F64VAL(0b01110000001, 0x000000, 0x00000000) // 1.1754943508222875e-38
|
||||||
|
},
|
||||||
|
// Smallest value that doesn't underflow to zero even without mantissa
|
||||||
|
// rounding.
|
||||||
|
{
|
||||||
|
F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38
|
||||||
|
F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38
|
||||||
|
F64VAL(0b01110000001, 0x000000, 0x00000000), // 1.1754943508222875e-38
|
||||||
|
F64VAL(0b01110000001, 0x000000, 0x00000000) // 1.1754943508222875e-38
|
||||||
|
},
|
||||||
|
// One (to make sure bias-handling is done correctly).
|
||||||
|
{
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x00000000) // 1.0
|
||||||
|
},
|
||||||
|
// Values in a space where ties round down due to ties-to-even:
|
||||||
|
// Value with highest mantissa that rounds down.
|
||||||
|
{
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x10000000), // 1.0000000596046448
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x00000000), // 1.0
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x10000000), // 1.0000000596046448
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x00000000) // 1.0
|
||||||
|
},
|
||||||
|
// Value with lowest mantissa that rounds up.
|
||||||
|
{
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x10000001), // 1.000000059604645
|
||||||
|
F64VAL(0b01111111111, 0x000001, 0x00000000), // 1.0000001192092896
|
||||||
|
F64VAL(0b01111111111, 0x000000, 0x10000001), // 1.000000059604645
|
||||||
|
F64VAL(0b01111111111, 0x000001, 0x00000000) // 1.0000001192092896
|
||||||
|
},
|
||||||
|
// Values in a space where ties round up due to ties-to-even:
|
||||||
|
// Value with highest mantissa that rounds down.
|
||||||
|
{
|
||||||
|
F64VAL(0b01111111111, 0x000001, 0x0fffffff), // 1.0000001788139341
|
||||||
|
F64VAL(0b01111111111, 0x000001, 0x00000000), // 1.0000001192092896
|
||||||
|
F64VAL(0b01111111111, 0x000001, 0x0fffffff), // 1.0000001788139341
|
||||||
|
F64VAL(0b01111111111, 0x000001, 0x00000000) // 1.0000001192092896
|
||||||
|
},
|
||||||
|
// Value with a mantissa that rounds up.
|
||||||
|
{
|
||||||
|
F64VAL(0b01111111111, 0x000001, 0x10000000), // 1.0000001788139343
|
||||||
|
F64VAL(0b01111111111, 0x000002, 0x00000000), // 1.0000002384185791
|
||||||
|
F64VAL(0b01111111111, 0x000001, 0x10000000), // 1.0000001788139343
|
||||||
|
F64VAL(0b01111111111, 0x000002, 0x00000000), // 1.0000002384185791
|
||||||
|
},
|
||||||
|
// Largest value that does not overflow to infinity.
|
||||||
|
{
|
||||||
|
F64VAL(0b10001111110, 0x7fffff, 0x0fffffff), // 3.4028235677973362e+38
|
||||||
|
F64VAL(0b10001111110, 0x7fffff, 0x00000000), // 3.4028234663852886e+38
|
||||||
|
F64VAL(0b10001111110, 0x7fffff, 0x0fffffff), // 3.4028235677973362e+38
|
||||||
|
F64VAL(0b10001111110, 0x7fffff, 0x00000000), // 3.4028234663852886e+38
|
||||||
|
},
|
||||||
|
// Smallest value that overflows to infinity due to mantissa rounding up.
|
||||||
|
{
|
||||||
|
F64VAL(0b10001111110, 0x7fffff, 0x10000000), // 3.4028235677973366e+38
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
|
||||||
|
F64VAL(0b10001111110, 0x7fffff, 0x10000000), // 3.4028235677973366e+38
|
||||||
|
F64VAL(0b10001111111, 0x000000, 0x00000000) // 3.4028236692093846e+38
|
||||||
|
},
|
||||||
|
// Smallest value that overflows to infinity, without mantissa rounding.
|
||||||
|
{
|
||||||
|
F64VAL(0b10001111111, 0x000000, 0x00000000), // 3.4028236692093846e+38
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
|
||||||
|
F64VAL(0b10001111111, 0x000000, 0x00000000) // 3.4028236692093846e+38
|
||||||
|
},
|
||||||
|
// Smallest value that overflows to infinity due to mantissa rounding up,
|
||||||
|
// even when exponent bits aren't reduced.
|
||||||
|
{
|
||||||
|
F64VAL(0b11111111110, 0x7fffff, 0x10000000), // 1.7976930812868855e+308
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000) // Inf
|
||||||
|
},
|
||||||
|
// True infinity.
|
||||||
|
{
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000000), // Inf
|
||||||
|
},
|
||||||
|
// NAN with a 1 in the preserved bits.
|
||||||
|
{
|
||||||
|
F64VAL(0b11111111111, 0x800000, 0x00000000), // -0
|
||||||
|
F64VAL(0b11111111111, 0x800000, 0x00000000), // -0
|
||||||
|
F64VAL(0b11111111111, 0x800000, 0x00000000), // -0
|
||||||
|
F64VAL(0b11111111111, 0x800000, 0x00000000), // -0
|
||||||
|
},
|
||||||
|
// NAN with a 1 in the truncated bits.
|
||||||
|
{
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN
|
||||||
|
F64VAL(0b11111111111, 0x000000, 0x00000001), // NaN
|
||||||
|
},
|
||||||
|
// NAN with all ones, causing rounding overflow.
|
||||||
|
{
|
||||||
|
F64VAL(0b11111111111, 0x7fffff, 0x1fffffff), // NaN
|
||||||
|
F64VAL(0b11111111111, 0x7fffff, 0x1fffffff), // NaN
|
||||||
|
F64VAL(0b11111111111, 0x7fffff, 0x1fffffff), // NaN
|
||||||
|
F64VAL(0b11111111111, 0x7fffff, 0x1fffffff), // NaN
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReducedPrecisionAccuracyTest : public ClientLibraryTestBase,
|
||||||
|
public ::testing::WithParamInterface<int> {
|
||||||
|
protected:
|
||||||
|
template <typename Fp, typename Uint, int kNumTestcases, int kNumInputs>
|
||||||
|
void DoIt(int exponent_bits, int mantissa_bits,
|
||||||
|
const Uint (&test_values)[kNumInputs][kNumTestcases],
|
||||||
|
int operation_index);
|
||||||
|
};
|
||||||
|
|
||||||
|
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionHalf) {
|
||||||
|
int operation_index = GetParam();
|
||||||
|
DoIt<Eigen::half, uint16>(f16_exponent_sizes[operation_index],
|
||||||
|
f16_mantissa_sizes[operation_index],
|
||||||
|
f16_test_values, operation_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionBfloat16) {
|
||||||
|
int operation_index = GetParam();
|
||||||
|
DoIt<bfloat16, uint16>(bf16_exponent_sizes[operation_index],
|
||||||
|
bf16_mantissa_sizes[operation_index], bf16_test_values,
|
||||||
|
operation_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionFloat) {
|
||||||
|
int operation_index = GetParam();
|
||||||
|
DoIt<float, uint32>(f32_exponent_sizes[operation_index],
|
||||||
|
f32_mantissa_sizes[operation_index], f32_test_values,
|
||||||
|
operation_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_P(ReducedPrecisionAccuracyTest, ReducePrecisionDouble) {
|
||||||
|
int operation_index = GetParam();
|
||||||
|
DoIt<double, uint64>(f64_exponent_sizes[operation_index],
|
||||||
|
f64_mantissa_sizes[operation_index], f64_test_values,
|
||||||
|
operation_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Fp, typename Uint, int kNumTestcases, int kNumInputs>
|
||||||
|
void ReducedPrecisionAccuracyTest::DoIt(
|
||||||
|
int exponent_bits, int mantissa_bits,
|
||||||
|
const Uint (&test_values)[kNumInputs][kNumTestcases], int operation_index) {
|
||||||
|
SCOPED_TRACE(absl::StrFormat("operation_index %d", operation_index));
|
||||||
|
SCOPED_TRACE(absl::StrFormat("%d exponent bits, %d mantissa bits",
|
||||||
|
exponent_bits, mantissa_bits));
|
||||||
|
|
||||||
|
std::vector<Fp> input_values;
|
||||||
|
std::vector<Fp> expected_values;
|
||||||
|
|
||||||
|
const Uint sign_bit = Uint{1} << (sizeof(Fp) * 8 - 1);
|
||||||
for (const auto& test_value : test_values) {
|
for (const auto& test_value : test_values) {
|
||||||
// Add positive values.
|
// Add positive values.
|
||||||
input_values.push_back(absl::bit_cast<float>(test_value[0]));
|
input_values.push_back(absl::bit_cast<Fp>(test_value[0]));
|
||||||
expected_values.push_back(absl::bit_cast<float>(test_value[index]));
|
expected_values.push_back(absl::bit_cast<Fp>(test_value[operation_index]));
|
||||||
// Add negative values. We do this in the bitwise representation so as to
|
// Add negative values. We do this in the bitwise representation so as to
|
||||||
// avoid problems with NaN handling.
|
// avoid problems with NaN handling.
|
||||||
input_values.push_back(absl::bit_cast<float>(test_value[0] ^ sign_bit));
|
input_values.push_back(absl::bit_cast<Fp, Uint>(test_value[0] ^ sign_bit));
|
||||||
expected_values.push_back(
|
expected_values.push_back(
|
||||||
absl::bit_cast<float>(test_value[index] ^ sign_bit));
|
absl::bit_cast<Fp, Uint>(test_value[operation_index] ^ sign_bit));
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is required for proper handling of NaN values.
|
// This is required for proper handling of NaN values.
|
||||||
@ -229,19 +519,18 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
|
|||||||
|
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
|
|
||||||
Literal a_literal = LiteralUtil::CreateR1<float>({input_values});
|
Literal a_literal = LiteralUtil::CreateR1<Fp>({input_values});
|
||||||
std::unique_ptr<GlobalData> a_data =
|
std::unique_ptr<GlobalData> a_data =
|
||||||
client_->TransferToServer(a_literal).ConsumeValueOrDie();
|
client_->TransferToServer(a_literal).ConsumeValueOrDie();
|
||||||
auto a = Parameter(&builder, 0, a_literal.shape(), "a");
|
auto a = Parameter(&builder, 0, a_literal.shape(), "a");
|
||||||
|
|
||||||
ReducePrecision(a, exponent_bits, mantissa_bits);
|
ReducePrecision(a, exponent_bits, mantissa_bits);
|
||||||
|
|
||||||
ComputeAndCompareR1<float>(&builder, expected_values, {a_data.get()});
|
ComputeAndCompareR1<Fp>(&builder, expected_values, {a_data.get()});
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(ReducePrecisionAccuracyTest,
|
INSTANTIATE_TEST_CASE_P(ReducedPrecisionAccuracyTest,
|
||||||
ReducePrecisionAccuracyTest,
|
ReducedPrecisionAccuracyTest, ::testing::Range(0, 4));
|
||||||
::testing::Values(0, 1, 2, 3), TestDataToString);
|
|
||||||
|
|
||||||
// Tests to confirm that the compiler optimization functions add the expected
|
// Tests to confirm that the compiler optimization functions add the expected
|
||||||
// ReducePrecisionInsertion passes.
|
// ReducePrecisionInsertion passes.
|
||||||
|
Loading…
Reference in New Issue
Block a user