Tweak round_to_bfloat16 to make it vectorizable.

This simplifies control flow by handling positive and
negative denormals separately. Should be ~40% faster.

PiperOrigin-RevId: 312095390
Change-Id: I5b6388e48b8c217edb0fc4fe14c3add64fb52c65
This commit is contained in:
Ilya Tokar 2020-05-18 09:35:23 -07:00 committed by TensorFlower Gardener
parent 9c49cda7d9
commit cfdb943405

View File

@ -194,171 +194,170 @@ struct bfloat16 {
input = f.u;
bfloat16 output;
// Fast rounding algorithm that rounds a half value to nearest even. This
// reduces expected error when we convert a large number of floats. Here
// is how it works:
//
// Definitions:
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
// with the following tags:
//
// Sign | Exp (8 bits) | Frac (23 bits)
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
//
// S: Sign bit.
// E: Exponent bits.
// F: First 6 bits of fraction.
// L: Least significant bit of resulting bfloat16 if we truncate away the
// rest of the float32. This is also the 7th bit of fraction
// R: Rounding bit, 8th bit of fraction.
// T: Sticky bits, rest of fraction, 15 bits.
//
// To round half to nearest even, there are 3 cases where we want to round
// down (simply truncate the result of the bits away, which consists of
// rounding bit and sticky bits) and two cases where we want to round up
// (truncate then add one to the result).
//
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
// 1s) as the rounding bias, adds the rounding bias to the input, then
// truncates the last 16 bits away.
//
// To understand how it works, we can analyze this algorithm case by case:
//
// 1. L = 0, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input may create any carry, depending on
// whether there is any value set to 1 in T bits.
// - R may be set to 1 if there is a carry.
// - L remains 0.
// - Note that this case also handles Inf and -Inf, where all fraction
// bits, including L, R and Ts are all 0. The output remains Inf after
// this algorithm.
//
// 2. L = 1, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits but
// adds 1 to rounding bit.
// - L remains 1.
//
// 3. L = 0, R = 1, all of T are 0:
// Expect: round down, this is exactly at half, the result is already
// even (L=0).
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input sets all sticky bits to 1, but
// doesn't create a carry.
// - R remains 1.
// - L remains 0.
//
// 4. L = 1, R = 1:
// Expect: round up, this is exactly at half, the result needs to be
// round to the next even number.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits, but
// creates a carry from rounding bit.
// - The carry sets L to 0, creates another carry bit and propagate
// forward to F bits.
// - If all the F bits are 1, a carry then propagates to the exponent
// bits, which then creates the minimum value with the next exponent
// value. Note that we won't have the case where exponents are all 1,
// since that's either a NaN (handled in the other if condition) or inf
// (handled in case 1).
//
// 5. L = 0, R = 1, any of T is 1:
// Expect: round up, this is greater than half.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input creates a carry from sticky bits,
// sets rounding bit to 0, then create another carry.
// - The second carry sets L to 1.
//
// Examples:
//
// Exact half value that is already even:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
//
// This falls into case 3. We truncate the rest of 16 bits and no
// carry is created into F and L:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
// Exact half value, round to next even number:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
//
// This falls into case 4. We create a carry from R and T,
// which then propagates into L and F:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
//
// Max denormal value round to min normal value:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
//
// Max normal value round to Inf:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
//
//
// Least significant bit of resulting bfloat.
uint32_t lsb = (input >> 16) & 1;
uint32_t rounding_bias = 0x7fff + lsb;
input += rounding_bias;
output.value = static_cast<uint16_t>(input >> 16);
if ((f.u & 0xff800000u) == 0) {
// Flush positive denormal to 0
output.value = 0x0;
}
if ((f.u & 0xff800000u) == 0x80000000u) {
// Flush negative denormal to -0
output.value = 0x8000;
}
if (float_isnan(v)) {
// If the value is a NaN, squash it to a qNaN with msb of fraction set,
// this makes sure after truncation we don't end up with an inf.
//
// qNaN magic: All exponent bits set + most significant bit of fraction
// set.
output.value = 0x7fc0;
} else if (std::fabs(v) < std::numeric_limits<float>::min()) {
// Flush denormal to +/- 0.0
output.value = std::signbit(v) ? 0x8000 : 0;
} else {
// Fast rounding algorithm that rounds a half value to nearest even. This
// reduces expected error when we convert a large number of floats. Here
// is how it works:
//
// Definitions:
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
// with the following tags:
//
// Sign | Exp (8 bits) | Frac (23 bits)
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
//
// S: Sign bit.
// E: Exponent bits.
// F: First 6 bits of fraction.
// L: Least significant bit of resulting bfloat16 if we truncate away the
// rest of the float32. This is also the 7th bit of fraction
// R: Rounding bit, 8th bit of fraction.
// T: Sticky bits, rest of fraction, 15 bits.
//
// To round half to nearest even, there are 3 cases where we want to round
// down (simply truncate the result of the bits away, which consists of
// rounding bit and sticky bits) and two cases where we want to round up
// (truncate then add one to the result).
//
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
// 1s) as the rounding bias, adds the rounding bias to the input, then
// truncates the last 16 bits away.
//
// To understand how it works, we can analyze this algorithm case by case:
//
// 1. L = 0, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input may create any carry, depending on
// whether there is any value set to 1 in T bits.
// - R may be set to 1 if there is a carry.
// - L remains 0.
// - Note that this case also handles Inf and -Inf, where all fraction
// bits, including L, R and Ts are all 0. The output remains Inf after
// this algorithm.
//
// 2. L = 1, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits but
// adds 1 to rounding bit.
// - L remains 1.
//
// 3. L = 0, R = 1, all of T are 0:
// Expect: round down, this is exactly at half, the result is already
// even (L=0).
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input sets all sticky bits to 1, but
// doesn't create a carry.
// - R remains 1.
// - L remains 0.
//
// 4. L = 1, R = 1:
// Expect: round up, this is exactly at half, the result needs to be
// round to the next even number.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits, but
// creates a carry from rounding bit.
// - The carry sets L to 0, creates another carry bit and propagate
// forward to F bits.
// - If all the F bits are 1, a carry then propagates to the exponent
// bits, which then creates the minimum value with the next exponent
// value. Note that we won't have the case where exponents are all 1,
// since that's either a NaN (handled in the other if condition) or inf
// (handled in case 1).
//
// 5. L = 0, R = 1, any of T is 1:
// Expect: round up, this is greater than half.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input creates a carry from sticky bits,
// sets rounding bit to 0, then create another carry.
// - The second carry sets L to 1.
//
// Examples:
//
// Exact half value that is already even:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
//
// This falls into case 3. We truncate the rest of 16 bits and no
// carry is created into F and L:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
// Exact half value, round to next even number:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
//
// This falls into case 4. We create a carry from R and T,
// which then propagates into L and F:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
//
// Max denormal value round to min normal value:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
//
// Max normal value round to Inf:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
//
//
// Least significant bit of resulting bfloat.
uint32_t lsb = (input >> 16) & 1;
uint32_t rounding_bias = 0x7fff + lsb;
input += rounding_bias;
output.value = static_cast<uint16_t>(input >> 16);
output.value = NAN_VALUE;
}
return output;
}