Tweak round_to_bfloat16 to make it vectorizable.
This simplifies control flow by handling positive and negative denormals separately. Should be ~40% faster. PiperOrigin-RevId: 312095390 Change-Id: I5b6388e48b8c217edb0fc4fe14c3add64fb52c65
This commit is contained in:
parent
9c49cda7d9
commit
cfdb943405
@ -194,171 +194,170 @@ struct bfloat16 {
|
||||
input = f.u;
|
||||
bfloat16 output;
|
||||
|
||||
// Fast rounding algorithm that rounds a half value to nearest even. This
|
||||
// reduces expected error when we convert a large number of floats. Here
|
||||
// is how it works:
|
||||
//
|
||||
// Definitions:
|
||||
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
|
||||
// with the following tags:
|
||||
//
|
||||
// Sign | Exp (8 bits) | Frac (23 bits)
|
||||
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
|
||||
//
|
||||
// S: Sign bit.
|
||||
// E: Exponent bits.
|
||||
// F: First 6 bits of fraction.
|
||||
// L: Least significant bit of resulting bfloat16 if we truncate away the
|
||||
// rest of the float32. This is also the 7th bit of fraction
|
||||
// R: Rounding bit, 8th bit of fraction.
|
||||
// T: Sticky bits, rest of fraction, 15 bits.
|
||||
//
|
||||
// To round half to nearest even, there are 3 cases where we want to round
|
||||
// down (simply truncate the result of the bits away, which consists of
|
||||
// rounding bit and sticky bits) and two cases where we want to round up
|
||||
// (truncate then add one to the result).
|
||||
//
|
||||
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
|
||||
// 1s) as the rounding bias, adds the rounding bias to the input, then
|
||||
// truncates the last 16 bits away.
|
||||
//
|
||||
// To understand how it works, we can analyze this algorithm case by case:
|
||||
//
|
||||
// 1. L = 0, R = 0:
|
||||
// Expect: round down, this is less than half value.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
||||
// - Adding rounding bias to input may create any carry, depending on
|
||||
// whether there is any value set to 1 in T bits.
|
||||
// - R may be set to 1 if there is a carry.
|
||||
// - L remains 0.
|
||||
// - Note that this case also handles Inf and -Inf, where all fraction
|
||||
// bits, including L, R and Ts are all 0. The output remains Inf after
|
||||
// this algorithm.
|
||||
//
|
||||
// 2. L = 1, R = 0:
|
||||
// Expect: round down, this is less than half value.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
||||
// - Adding rounding bias to input doesn't change sticky bits but
|
||||
// adds 1 to rounding bit.
|
||||
// - L remains 1.
|
||||
//
|
||||
// 3. L = 0, R = 1, all of T are 0:
|
||||
// Expect: round down, this is exactly at half, the result is already
|
||||
// even (L=0).
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
||||
// - Adding rounding bias to input sets all sticky bits to 1, but
|
||||
// doesn't create a carry.
|
||||
// - R remains 1.
|
||||
// - L remains 0.
|
||||
//
|
||||
// 4. L = 1, R = 1:
|
||||
// Expect: round up, this is exactly at half, the result needs to be
|
||||
// round to the next even number.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
||||
// - Adding rounding bias to input doesn't change sticky bits, but
|
||||
// creates a carry from rounding bit.
|
||||
// - The carry sets L to 0, creates another carry bit and propagate
|
||||
// forward to F bits.
|
||||
// - If all the F bits are 1, a carry then propagates to the exponent
|
||||
// bits, which then creates the minimum value with the next exponent
|
||||
// value. Note that we won't have the case where exponents are all 1,
|
||||
// since that's either a NaN (handled in the other if condition) or inf
|
||||
// (handled in case 1).
|
||||
//
|
||||
// 5. L = 0, R = 1, any of T is 1:
|
||||
// Expect: round up, this is greater than half.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
||||
// - Adding rounding bias to input creates a carry from sticky bits,
|
||||
// sets rounding bit to 0, then create another carry.
|
||||
// - The second carry sets L to 1.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// Exact half value that is already even:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
|
||||
//
|
||||
// This falls into case 3. We truncate the rest of 16 bits and no
|
||||
// carry is created into F and L:
|
||||
//
|
||||
// Output:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
||||
//
|
||||
// Exact half value, round to next even number:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
|
||||
//
|
||||
// This falls into case 4. We create a carry from R and T,
|
||||
// which then propagates into L and F:
|
||||
//
|
||||
// Output:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
||||
//
|
||||
//
|
||||
// Max denormal value round to min normal value:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
|
||||
//
|
||||
// This falls into case 4. We create a carry from R and T,
|
||||
// propagate into L and F, which then propagates into exponent
|
||||
// bits:
|
||||
//
|
||||
// Output:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
|
||||
//
|
||||
// Max normal value round to Inf:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
|
||||
//
|
||||
// This falls into case 4. We create a carry from R and T,
|
||||
// propagate into L and F, which then propagates into exponent
|
||||
// bits:
|
||||
//
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
|
||||
//
|
||||
//
|
||||
// Least significant bit of resulting bfloat.
|
||||
uint32_t lsb = (input >> 16) & 1;
|
||||
uint32_t rounding_bias = 0x7fff + lsb;
|
||||
input += rounding_bias;
|
||||
output.value = static_cast<uint16_t>(input >> 16);
|
||||
if ((f.u & 0xff800000u) == 0) {
|
||||
// Flush positive denormal to 0
|
||||
output.value = 0x0;
|
||||
}
|
||||
if ((f.u & 0xff800000u) == 0x80000000u) {
|
||||
// Flush negative denormal to -0
|
||||
output.value = 0x8000;
|
||||
}
|
||||
if (float_isnan(v)) {
|
||||
// If the value is a NaN, squash it to a qNaN with msb of fraction set,
|
||||
// this makes sure after truncation we don't end up with an inf.
|
||||
//
|
||||
// qNaN magic: All exponent bits set + most significant bit of fraction
|
||||
// set.
|
||||
output.value = 0x7fc0;
|
||||
} else if (std::fabs(v) < std::numeric_limits<float>::min()) {
|
||||
// Flush denormal to +/- 0.0
|
||||
output.value = std::signbit(v) ? 0x8000 : 0;
|
||||
} else {
|
||||
// Fast rounding algorithm that rounds a half value to nearest even. This
|
||||
// reduces expected error when we convert a large number of floats. Here
|
||||
// is how it works:
|
||||
//
|
||||
// Definitions:
|
||||
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
|
||||
// with the following tags:
|
||||
//
|
||||
// Sign | Exp (8 bits) | Frac (23 bits)
|
||||
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
|
||||
//
|
||||
// S: Sign bit.
|
||||
// E: Exponent bits.
|
||||
// F: First 6 bits of fraction.
|
||||
// L: Least significant bit of resulting bfloat16 if we truncate away the
|
||||
// rest of the float32. This is also the 7th bit of fraction
|
||||
// R: Rounding bit, 8th bit of fraction.
|
||||
// T: Sticky bits, rest of fraction, 15 bits.
|
||||
//
|
||||
// To round half to nearest even, there are 3 cases where we want to round
|
||||
// down (simply truncate the result of the bits away, which consists of
|
||||
// rounding bit and sticky bits) and two cases where we want to round up
|
||||
// (truncate then add one to the result).
|
||||
//
|
||||
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
|
||||
// 1s) as the rounding bias, adds the rounding bias to the input, then
|
||||
// truncates the last 16 bits away.
|
||||
//
|
||||
// To understand how it works, we can analyze this algorithm case by case:
|
||||
//
|
||||
// 1. L = 0, R = 0:
|
||||
// Expect: round down, this is less than half value.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
||||
// - Adding rounding bias to input may create any carry, depending on
|
||||
// whether there is any value set to 1 in T bits.
|
||||
// - R may be set to 1 if there is a carry.
|
||||
// - L remains 0.
|
||||
// - Note that this case also handles Inf and -Inf, where all fraction
|
||||
// bits, including L, R and Ts are all 0. The output remains Inf after
|
||||
// this algorithm.
|
||||
//
|
||||
// 2. L = 1, R = 0:
|
||||
// Expect: round down, this is less than half value.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
||||
// - Adding rounding bias to input doesn't change sticky bits but
|
||||
// adds 1 to rounding bit.
|
||||
// - L remains 1.
|
||||
//
|
||||
// 3. L = 0, R = 1, all of T are 0:
|
||||
// Expect: round down, this is exactly at half, the result is already
|
||||
// even (L=0).
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
||||
// - Adding rounding bias to input sets all sticky bits to 1, but
|
||||
// doesn't create a carry.
|
||||
// - R remains 1.
|
||||
// - L remains 0.
|
||||
//
|
||||
// 4. L = 1, R = 1:
|
||||
// Expect: round up, this is exactly at half, the result needs to be
|
||||
// round to the next even number.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
||||
// - Adding rounding bias to input doesn't change sticky bits, but
|
||||
// creates a carry from rounding bit.
|
||||
// - The carry sets L to 0, creates another carry bit and propagate
|
||||
// forward to F bits.
|
||||
// - If all the F bits are 1, a carry then propagates to the exponent
|
||||
// bits, which then creates the minimum value with the next exponent
|
||||
// value. Note that we won't have the case where exponents are all 1,
|
||||
// since that's either a NaN (handled in the other if condition) or inf
|
||||
// (handled in case 1).
|
||||
//
|
||||
// 5. L = 0, R = 1, any of T is 1:
|
||||
// Expect: round up, this is greater than half.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
||||
// - Adding rounding bias to input creates a carry from sticky bits,
|
||||
// sets rounding bit to 0, then create another carry.
|
||||
// - The second carry sets L to 1.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// Exact half value that is already even:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
|
||||
//
|
||||
// This falls into case 3. We truncate the rest of 16 bits and no
|
||||
// carry is created into F and L:
|
||||
//
|
||||
// Output:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
||||
//
|
||||
// Exact half value, round to next even number:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
|
||||
//
|
||||
// This falls into case 4. We create a carry from R and T,
|
||||
// which then propagates into L and F:
|
||||
//
|
||||
// Output:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
||||
//
|
||||
//
|
||||
// Max denormal value round to min normal value:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
|
||||
//
|
||||
// This falls into case 4. We create a carry from R and T,
|
||||
// propagate into L and F, which then propagates into exponent
|
||||
// bits:
|
||||
//
|
||||
// Output:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
|
||||
//
|
||||
// Max normal value round to Inf:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
|
||||
//
|
||||
// This falls into case 4. We create a carry from R and T,
|
||||
// propagate into L and F, which then propagates into exponent
|
||||
// bits:
|
||||
//
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
|
||||
//
|
||||
//
|
||||
// Least significant bit of resulting bfloat.
|
||||
uint32_t lsb = (input >> 16) & 1;
|
||||
uint32_t rounding_bias = 0x7fff + lsb;
|
||||
input += rounding_bias;
|
||||
output.value = static_cast<uint16_t>(input >> 16);
|
||||
output.value = NAN_VALUE;
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user