Replace tensorflow::bfloat16 with Eigen::bfloat16

This commit is contained in:
ShengYang1 2020-07-01 11:13:16 +08:00
parent aed5e3ea00
commit 3451f21c1e
10 changed files with 59 additions and 608 deletions

View File

@ -28,20 +28,20 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
@ -897,7 +897,7 @@ static DenseElementsAttr GetEpsilonValue(Type ty) {
auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
return DenseElementsAttr::get(scalar_ty, value);
} else if (element_ty.isBF16()) {
uint16_t raw_epsilon = tensorflow::bfloat16::epsilon().value;
uint16_t raw_epsilon = Eigen::NumTraits<Eigen::bfloat16>::epsilon().value;
auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
return DenseElementsAttr::get(scalar_ty, value);
} else if (element_ty.isF32()) {

View File

@ -48,7 +48,9 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
builder,
static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
case BF16:
return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
return ConstantR0<Eigen::bfloat16>(
builder, static_cast<Eigen::bfloat16>(
Eigen::NumTraits<Eigen::bfloat16>::epsilon()));
case F32:
return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
case F64:
@ -70,7 +72,8 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
return ConstantR0<Eigen::half>(builder,
Eigen::NumTraits<Eigen::half>::lowest());
case BF16:
return ConstantR0<bfloat16>(builder, bfloat16::lowest());
return ConstantR0<Eigen::bfloat16>(
builder, Eigen::NumTraits<Eigen::bfloat16>::lowest());
case F32:
return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
case F64:
@ -86,7 +89,8 @@ XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
return ConstantR0<Eigen::half>(builder,
std::numeric_limits<Eigen::half>::min());
case BF16:
return ConstantR0<bfloat16>(builder, bfloat16::min_positive_normal());
return ConstantR0<Eigen::bfloat16>(
builder, std::numeric_limits<Eigen::bfloat16>::min());
case F32:
return ConstantR0<float>(builder, std::numeric_limits<float>::min());
case F64:
@ -108,7 +112,8 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
return ConstantR0<Eigen::half>(builder,
Eigen::NumTraits<Eigen::half>::highest());
case BF16:
return ConstantR0<bfloat16>(builder, bfloat16::highest());
return ConstantR0<Eigen::bfloat16>(
builder, Eigen::NumTraits<Eigen::bfloat16>::highest());
case F32:
return ConstantR0<float>(builder, std::numeric_limits<float>::max());
case F64:
@ -125,8 +130,8 @@ XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
return ConstantR0<Eigen::half>(
builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
case BF16:
return ConstantR0<bfloat16>(
builder, bfloat16(std::numeric_limits<float>::quiet_NaN()));
return ConstantR0<Eigen::bfloat16>(
builder, Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN());
case F32:
return ConstantR0<float>(builder,
std::numeric_limits<float>::quiet_NaN());

View File

@ -2678,7 +2678,9 @@ struct MinMaxFiniteValue<Eigen::half> {
template <>
struct MinMaxFiniteValue<bfloat16> {
static double max() { return static_cast<double>(bfloat16::highest()); }
static double max() {
return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
}
static double min() { return -max(); }
};

View File

@ -103,7 +103,8 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(
// The largest negative number smaller than zero in bf16 that's not
// denormalized.
std::make_pair(static_cast<float>(-bfloat16::min_positive_normal()),
std::make_pair(static_cast<float>(
-std::numeric_limits<Eigen::bfloat16>::min()),
0.0f),
// Test odd and even values.
std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f),

View File

@ -35,14 +35,16 @@ TEST(Bfloat16Test, FlushDenormalsToZero) {
for (float denorm = -std::numeric_limits<float>::denorm_min();
denorm < std::numeric_limits<float>::denorm_min();
denorm = std::nextafterf(denorm, 1.0f)) {
bfloat16 bf_trunc = bfloat16::truncate_to_bfloat16(denorm);
bfloat16 bf_trunc =
bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16(denorm));
ASSERT_EQ(static_cast<float>(bf_trunc), 0.0f);
if (std::signbit(denorm)) {
ASSERT_EQ(bf_trunc.value, 0x8000) << denorm;
} else {
ASSERT_EQ(bf_trunc.value, 0x0000) << denorm;
}
bfloat16 bf_round = bfloat16::round_to_bfloat16(denorm);
bfloat16 bf_round =
bfloat16(Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm));
ASSERT_EQ(static_cast<float>(bf_round), 0.0f);
if (std::signbit(denorm)) {
ASSERT_EQ(bf_round.value, 0x8000) << denorm;
@ -88,7 +90,8 @@ class Bfloat16Test : public ::testing::Test,
public ::testing::WithParamInterface<Bfloat16TestParam> {};
TEST_P(Bfloat16Test, TruncateTest) {
bfloat16 truncated = bfloat16::truncate_to_bfloat16((GetParam().input));
bfloat16 truncated =
bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16((GetParam().input)));
if (std::isnan(GetParam().input)) {
EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
@ -97,7 +100,8 @@ TEST_P(Bfloat16Test, TruncateTest) {
EXPECT_EQ(GetParam().expected_truncation, float(truncated));
bfloat16 rounded = bfloat16::round_to_bfloat16((GetParam().input));
bfloat16 rounded = bfloat16(
Eigen::bfloat16_impl::float_to_bfloat16_rtne((GetParam().input)));
if (std::isnan(GetParam().input)) {
EXPECT_TRUE(std::isnan(float(rounded)) || std::isinf(float(rounded)));
return;
@ -172,9 +176,13 @@ TEST(Bfloat16Test, Conversion) {
}
TEST(Bfloat16Test, Epsilon) {
EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
bfloat16(1.0f)));
EXPECT_LT(1.0f,
static_cast<float>(Eigen::NumTraits<Eigen::bfloat16>::epsilon() +
bfloat16(1.0f)));
EXPECT_EQ(1.0f,
static_cast<float>((Eigen::NumTraits<Eigen::bfloat16>::epsilon() /
bfloat16(2.0f)) +
bfloat16(1.0f)));
}
TEST(Bfloat16Test, Negate) {

View File

@ -43,47 +43,17 @@ typedef Eigen::QUInt16 quint16;
} // namespace tensorflow
static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) {
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
return *reinterpret_cast<tensorflow::bfloat16*>(
reinterpret_cast<uint16_t*>(&float_val));
return *reinterpret_cast<tensorflow::bfloat16*>(
reinterpret_cast<uint16_t*>(&float_val));
#else
return *reinterpret_cast<tensorflow::bfloat16*>(
&(reinterpret_cast<uint16_t*>(&float_val)[1]));
return *reinterpret_cast<tensorflow::bfloat16*>(
&(reinterpret_cast<uint16_t*>(&float_val)[1]));
#endif
}
namespace Eigen {
// TODO(xpan): We probably need to overwrite more methods to have correct eigen
// behavior. E.g. epsilon(), dummy_precision, etc. See NumTraits.h in eigen.
template <>
struct NumTraits<tensorflow::bfloat16>
: GenericNumTraits<tensorflow::bfloat16> {
enum {
IsInteger = 0,
IsSigned = 1,
RequireInitialization = 0
};
static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() {
return FloatToBFloat16(NumTraits<float>::highest());
}
static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() {
return FloatToBFloat16(NumTraits<float>::lowest());
}
static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() {
return FloatToBFloat16(NumTraits<float>::infinity());
}
static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() {
return FloatToBFloat16(NumTraits<float>::quiet_NaN());
}
};
template <>
struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
enum {
@ -104,30 +74,6 @@ struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
static inline tensorflow::tstring quiet_NaN();
};
using ::tensorflow::operator==;
using ::tensorflow::operator!=;
namespace numext {
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 log(
const tensorflow::bfloat16& x) {
return static_cast<tensorflow::bfloat16>(::logf(static_cast<float>(x)));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 exp(
const tensorflow::bfloat16& x) {
return static_cast<tensorflow::bfloat16>(::expf(static_cast<float>(x)));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs(
const tensorflow::bfloat16& x) {
return static_cast<tensorflow::bfloat16>(::fabsf(static_cast<float>(x)));
}
} // namespace numext
} // namespace Eigen
#if defined(_MSC_VER) && !defined(__clang__)

View File

@ -16,13 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// Note that the GPU cast functor templates need to be instantiated unlike the
// CPU ones, and hence their specializations are different than that for CPUs.
@ -278,48 +278,6 @@ template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
: functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};
// Specialized cast op impls for bfloat16.
template <>
struct scalar_cast_op<::tensorflow::bfloat16, float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef float result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
const ::tensorflow::bfloat16& a) const {
float ret;
uint16_t* p = reinterpret_cast<uint16_t*>(&ret);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
p[0] = a.value;
p[1] = 0;
#else
static_assert(::tensorflow::port::kLittleEndian,
"Not a little endian system!");
p[0] = 0;
p[1] = a.value;
#endif
return ret;
}
};
template <>
struct functor_traits<scalar_cast_op<::tensorflow::bfloat16, float>> {
enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
};
template <>
struct scalar_cast_op<float, ::tensorflow::bfloat16> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef ::tensorflow::bfloat16 result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()(
const float a) const {
return ::tensorflow::bfloat16(a);
}
};
template <>
struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> {
enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
};
} // namespace internal
} // namespace Eigen

View File

@ -23,7 +23,6 @@ limitations under the License.
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
@ -37,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "include/libxsmm_intrinsics_x86.h"
#include "include/libxsmm_malloc.h"
@ -165,7 +165,7 @@ bool IsZero(T v);
template <>
ALWAYS_INLINE bool IsZero(bfloat16 v) {
return v.IsZero();
return !static_cast<bool>(v);
}
template <>

View File

@ -17,12 +17,4 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
namespace tensorflow {
const uint16_t bfloat16::NAN_VALUE;
const uint16_t bfloat16::ZERO_VALUE;
B16_DEVICE_FUNC bfloat16::operator Eigen::half() const {
return static_cast<Eigen::half>(float(*this));
}
} // end namespace tensorflow
namespace tensorflow {} // end namespace tensorflow

View File

@ -22,479 +22,18 @@ limitations under the License.
#include <limits>
#include "tensorflow/core/platform/byte_order.h"
#if defined(__CUDACC__) || (defined(__HIPCC__) && defined(__HIP__))
// All functions callable from CUDA code must be qualified with __device__
#define B16_DEVICE_FUNC __host__ __device__
#else
#define B16_DEVICE_FUNC
#endif
#include "third_party/eigen3/Eigen/Core"
namespace Eigen {
struct bfloat16;
struct half;
}
} // namespace Eigen
namespace tensorflow {
// Single precision complex.
typedef std::complex<float> complex64;
// Double precision complex.
typedef std::complex<double> complex128;
// see framework/bfloat16.h for description.
struct bfloat16 {
// The default constructor must yield a zero value, not an uninitialized
// value; some TF kernels use T() as a zero value.
B16_DEVICE_FUNC bfloat16() : value(ZERO_VALUE) {}
B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
bfloat16 output;
if (float_isnan(v)) {
output.value = NAN_VALUE;
return output;
} else if (std::fabs(v) < std::numeric_limits<float>::min()) {
// Flush denormal to +/- 0.
output.value = std::signbit(v) ? 0x8000 : 0;
return output;
}
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
output.value = p[0];
#else
output.value = p[1];
#endif
return output;
}
B16_DEVICE_FUNC explicit bfloat16(const float v) {
value = round_to_bfloat16(v).value;
}
B16_DEVICE_FUNC explicit bfloat16(const double val)
: bfloat16(static_cast<float>(val)) {}
// Following the convention of numpy, converting between complex and
// float will lead to loss of imag value.
B16_DEVICE_FUNC explicit bfloat16(const complex64& val)
: bfloat16(val.real()) {}
B16_DEVICE_FUNC explicit bfloat16(const complex128& val)
: bfloat16(static_cast<float>(val.real())) {}
B16_DEVICE_FUNC explicit bfloat16(const unsigned short val)
: bfloat16(static_cast<float>(val)) {}
B16_DEVICE_FUNC explicit bfloat16(const unsigned int val)
: bfloat16(static_cast<float>(val)) {}
B16_DEVICE_FUNC explicit bfloat16(const int val)
: bfloat16(static_cast<float>(val)) {}
B16_DEVICE_FUNC explicit bfloat16(const long val)
: bfloat16(static_cast<float>(val)) {}
B16_DEVICE_FUNC explicit bfloat16(const long long val)
: bfloat16(static_cast<float>(val)) {}
template <class T>
B16_DEVICE_FUNC explicit bfloat16(const T& val)
: bfloat16(static_cast<float>(val)) {}
B16_DEVICE_FUNC explicit operator float() const {
float result = 0;
uint16_t* q = reinterpret_cast<uint16_t*>(&result);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
q[0] = value;
#else
q[1] = value;
#endif
return result;
}
B16_DEVICE_FUNC explicit operator bool() const {
return static_cast<bool>(float(*this));
}
B16_DEVICE_FUNC explicit operator Eigen::half() const;
B16_DEVICE_FUNC explicit operator short() const {
return static_cast<short>(float(*this));
}
B16_DEVICE_FUNC explicit operator int() const {
return static_cast<int>(float(*this));
}
B16_DEVICE_FUNC explicit operator long() const {
return static_cast<long>(float(*this));
}
B16_DEVICE_FUNC explicit operator char() const {
return static_cast<char>(float(*this));
}
B16_DEVICE_FUNC explicit operator signed char() const {
return static_cast<signed char>(float(*this));
}
B16_DEVICE_FUNC explicit operator unsigned char() const {
return static_cast<unsigned char>(float(*this));
}
B16_DEVICE_FUNC explicit operator unsigned short() const {
return static_cast<unsigned short>(float(*this));
}
B16_DEVICE_FUNC explicit operator unsigned int() const {
return static_cast<unsigned int>(float(*this));
}
B16_DEVICE_FUNC explicit operator unsigned long() const {
return static_cast<unsigned long>(float(*this));
}
B16_DEVICE_FUNC explicit operator unsigned long long() const {
return static_cast<unsigned long long>(float(*this));
}
B16_DEVICE_FUNC explicit operator long long() const {
return static_cast<long long>(float(*this));
}
B16_DEVICE_FUNC explicit operator double() const {
return static_cast<double>(float(*this));
}
B16_DEVICE_FUNC explicit operator complex64() const {
return complex64(float(*this), float(0.0));
}
B16_DEVICE_FUNC explicit operator complex128() const {
return complex128(double(*this), double(0.0));
}
union FP32 {
unsigned int u;
float f;
};
// Converts a float point to bfloat16, with round-nearest-to-even as rounding
// method.
// TODO: There is a slightly faster implementation (8% faster on CPU)
// than this (documented in cl/175987786), that is exponentially harder to
// understand and document. Switch to the faster version when converting to
// BF16 becomes compute-bound.
B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) {
uint32_t input;
FP32 f;
f.f = v;
input = f.u;
bfloat16 output;
// Fast rounding algorithm that rounds a half value to nearest even. This
// reduces expected error when we convert a large number of floats. Here
// is how it works:
//
// Definitions:
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
// with the following tags:
//
// Sign | Exp (8 bits) | Frac (23 bits)
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
//
// S: Sign bit.
// E: Exponent bits.
// F: First 6 bits of fraction.
// L: Least significant bit of resulting bfloat16 if we truncate away the
// rest of the float32. This is also the 7th bit of fraction
// R: Rounding bit, 8th bit of fraction.
// T: Sticky bits, rest of fraction, 15 bits.
//
// To round half to nearest even, there are 3 cases where we want to round
// down (simply truncate the result of the bits away, which consists of
// rounding bit and sticky bits) and two cases where we want to round up
// (truncate then add one to the result).
//
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
// 1s) as the rounding bias, adds the rounding bias to the input, then
// truncates the last 16 bits away.
//
// To understand how it works, we can analyze this algorithm case by case:
//
// 1. L = 0, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input may create any carry, depending on
// whether there is any value set to 1 in T bits.
// - R may be set to 1 if there is a carry.
// - L remains 0.
// - Note that this case also handles Inf and -Inf, where all fraction
// bits, including L, R and Ts are all 0. The output remains Inf after
// this algorithm.
//
// 2. L = 1, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits but
// adds 1 to rounding bit.
// - L remains 1.
//
// 3. L = 0, R = 1, all of T are 0:
// Expect: round down, this is exactly at half, the result is already
// even (L=0).
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input sets all sticky bits to 1, but
// doesn't create a carry.
// - R remains 1.
// - L remains 0.
//
// 4. L = 1, R = 1:
// Expect: round up, this is exactly at half, the result needs to be
// round to the next even number.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits, but
// creates a carry from rounding bit.
// - The carry sets L to 0, creates another carry bit and propagate
// forward to F bits.
// - If all the F bits are 1, a carry then propagates to the exponent
// bits, which then creates the minimum value with the next exponent
// value. Note that we won't have the case where exponents are all 1,
// since that's either a NaN (handled in the other if condition) or inf
// (handled in case 1).
//
// 5. L = 0, R = 1, any of T is 1:
// Expect: round up, this is greater than half.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input creates a carry from sticky bits,
// sets rounding bit to 0, then create another carry.
// - The second carry sets L to 1.
//
// Examples:
//
// Exact half value that is already even:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
//
// This falls into case 3. We truncate the rest of 16 bits and no
// carry is created into F and L:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
// Exact half value, round to next even number:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
//
// This falls into case 4. We create a carry from R and T,
// which then propagates into L and F:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
//
// Max denormal value round to min normal value:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
//
// Max normal value round to Inf:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
//
//
// Least significant bit of resulting bfloat.
uint32_t lsb = (input >> 16) & 1;
uint32_t rounding_bias = 0x7fff + lsb;
input += rounding_bias;
output.value = static_cast<uint16_t>(input >> 16);
if ((f.u & 0xff800000u) == 0) {
// Flush positive denormal to 0
output.value = 0x0;
}
if ((f.u & 0xff800000u) == 0x80000000u) {
// Flush negative denormal to -0
output.value = 0x8000;
}
if (float_isnan(v)) {
output.value = NAN_VALUE;
}
return output;
}
static bfloat16 epsilon() {
bfloat16 x;
x.value = 0x3c00; // 0x1.0p-7
return x;
}
static bfloat16 highest() {
bfloat16 x;
x.value = 0x7F7F; // 0x1.FEp127
return x;
}
static bfloat16 lowest() {
bfloat16 x;
x.value = 0xFF7F; // -0x1.FEp127
return x;
}
static bfloat16 min_positive_normal() {
bfloat16 x;
x.value = 0x0080; // 0x1p-126
return x;
}
bool IsZero() const { return (value & 0x7FFF) == ZERO_VALUE; }
uint16_t value;
// A value that represents "not a number".
static constexpr uint16_t NAN_VALUE = 0x7FC0;
private:
// A value that represents "zero".
static constexpr uint16_t ZERO_VALUE = 0;
B16_DEVICE_FUNC static bool float_isnan(const float& x) {
#ifdef __CUDA_ARCH__
return ::isnan(x);
#else
return std::isnan(x);
#endif
}
};
B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os,
const bfloat16& dt) {
os << static_cast<float>(dt);
return os;
}
B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) {
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) {
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
return bfloat16(static_cast<float>(a) - static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
return bfloat16(static_cast<float>(a) * static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
return bfloat16(static_cast<float>(a) / static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) {
a.value ^= 0x8000;
return a;
}
B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) {
return static_cast<float>(a) < static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) {
return static_cast<float>(a) <= static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) {
return static_cast<float>(a) == static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) {
return static_cast<float>(a) != static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) {
return static_cast<float>(a) > static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) {
return static_cast<float>(a) >= static_cast<float>(b);
}
B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) {
a = a + b;
return a;
}
B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) {
a = a - b;
return a;
}
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) {
a += bfloat16(1);
return a;
}
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) {
a -= bfloat16(1);
return a;
}
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) {
bfloat16 original_value = a;
++a;
return original_value;
}
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) {
bfloat16 original_value = a;
--a;
return original_value;
}
B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) {
a = a * b;
return a;
}
B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) {
a = a / b;
return a;
}
typedef Eigen::bfloat16 bfloat16;
} // end namespace tensorflow
namespace std {
template <>
struct hash<tensorflow::bfloat16> {
size_t operator()(const tensorflow::bfloat16& v) const {
return hash<float>()(static_cast<float>(v));
}
};
using tensorflow::bfloat16;
inline bool isinf(const bfloat16& a) { return std::isinf(float(a)); }