Add bfloat support to XLA.

This is necessary in providing bfloat support in GPU backend.
RELNOTES: bfloat support is now added to XLA infra.
PiperOrigin-RevId: 175252067
This commit is contained in:
Yunxing Dai 2017-11-09 20:45:39 -08:00 committed by Andrew Selle
parent 685f604f63
commit 64d9aa1ace
19 changed files with 580 additions and 44 deletions

View File

@ -49,6 +49,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
case tensorflow::DT_UINT64: case tensorflow::DT_UINT64:
*type = xla::U64; *type = xla::U64;
return Status::OK(); return Status::OK();
case tensorflow::DT_BFLOAT16:
*type = xla::BF16;
return Status::OK();
case tensorflow::DT_HALF: case tensorflow::DT_HALF:
*type = xla::F16; *type = xla::F16;
return Status::OK(); return Status::OK();

View File

@ -77,6 +77,7 @@ cc_library(
hdrs = ["types.h"], hdrs = ["types.h"],
visibility = [":friends"], visibility = [":friends"],
deps = [ deps = [
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//third_party/eigen3", "//third_party/eigen3",
], ],

View File

@ -33,6 +33,20 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
namespace {
using tensorflow::int64;
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Converts between little and big endian, assuming elements in the array are 16
// bits long.
void ConvertEndianShort(char* bytes, int64 size) {
CHECK_EQ(size / 2, 0);
for (int64 i = 0; i < size; i += 2) {
std::swap(bytes[i], bytes[i + 1]);
}
}
} // namespace
namespace xla { namespace xla {
@ -169,6 +183,8 @@ Status Literal::Copy(const Literal& src_literal,
return CopyRange<int64>(src_literal, src_base, dest_base, copy_size); return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
case F16: case F16:
return CopyRange<half>(src_literal, src_base, dest_base, copy_size); return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
case BF16:
return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
case F32: case F32:
return CopyRange<float>(src_literal, src_base, dest_base, copy_size); return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
case F64: case F64:
@ -200,6 +216,8 @@ Status Literal::Copy(const Literal& src_literal,
return *Literal::CreateR0<int64>(0); return *Literal::CreateR0<int64>(0);
case F16: case F16:
return *Literal::CreateR0<half>(static_cast<half>(0.0f)); return *Literal::CreateR0<half>(static_cast<half>(0.0f));
case BF16:
return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32: case F32:
return *Literal::CreateR0<float>(0); return *Literal::CreateR0<float>(0);
case F64: case F64:
@ -285,6 +303,9 @@ Status Literal::Copy(const Literal& src_literal,
case F16: case F16:
return *Literal::CreateR0<half>( return *Literal::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity())); static_cast<half>(-std::numeric_limits<float>::infinity()));
case BF16:
return *Literal::CreateR0<bfloat16>(
static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE: case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value"; LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE: case OPAQUE:
@ -321,6 +342,9 @@ Status Literal::Copy(const Literal& src_literal,
case F16: case F16:
return *Literal::CreateR0<half>( return *Literal::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity())); static_cast<half>(std::numeric_limits<float>::infinity()));
case BF16:
return *Literal::CreateR0<bfloat16>(
static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE: case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value"; LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE: case OPAQUE:
@ -428,6 +452,7 @@ std::unique_ptr<Literal> Literal::Transpose(
// The shape with affine layout resulting from that operation will be // The shape with affine layout resulting from that operation will be
// F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
// most minor. // most minor.
//
// Essentially, given MinMaj(Di) the position of the Di dimension within the // Essentially, given MinMaj(Di) the position of the Di dimension within the
// minor to major vector, and given T(Di) the index that the original Di // minor to major vector, and given T(Di) the index that the original Di
// dimension has within the transposed array, a layout is affine if // dimension has within the transposed array, a layout is affine if
@ -536,6 +561,9 @@ string Literal::GetAsString(
} }
case F16: case F16:
return tensorflow::strings::StrCat(Get<half>(multi_index)); return tensorflow::strings::StrCat(Get<half>(multi_index));
case BF16:
return tensorflow::strings::StrCat(
static_cast<float>(Get<bfloat16>(multi_index)));
default: default:
return tensorflow::strings::StrCat( return tensorflow::strings::StrCat(
"[", PrimitiveType_Name(shape().element_type()), "]"); "[", PrimitiveType_Name(shape().element_type()), "]");
@ -743,6 +771,8 @@ void* Literal::MutableInternalData() {
return reinterpret_cast<void*>(c64s_.data()); return reinterpret_cast<void*>(c64s_.data());
case F16: case F16:
return reinterpret_cast<void*>(f16s_.data()); return reinterpret_cast<void*>(f16s_.data());
case BF16:
return reinterpret_cast<void*>(bf16s_.data());
default: default:
LOG(FATAL) << "primitive type not supported in literals: " LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type()); << PrimitiveType_Name(shape().element_type());
@ -785,6 +815,9 @@ void Literal::Reserve(int64 num_elements) {
case F16: case F16:
Resize<half>(num_elements, static_cast<half>(0.0f)); Resize<half>(num_elements, static_cast<half>(0.0f));
break; break;
case BF16:
Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
break;
default: default:
LOG(FATAL) << "primitive type not supported in literals: " LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type()); << PrimitiveType_Name(shape().element_type());
@ -824,6 +857,9 @@ tensorflow::Status Literal::ValidateLiteral() const {
case F16: case F16:
actual = f16s().size() / sizeof(half); actual = f16s().size() / sizeof(half);
break; break;
case BF16:
actual = bf16s().size();
break;
default: default:
return tensorflow::errors::Unimplemented( return tensorflow::errors::Unimplemented(
"unhandled element type for literal validation: " + "unhandled element type for literal validation: " +
@ -920,6 +956,7 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
CONVERT_IF_TYPES_MATCH(F16) CONVERT_IF_TYPES_MATCH(F16)
CONVERT_IF_TYPES_MATCH(F32) CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64) CONVERT_IF_TYPES_MATCH(F64)
CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH #undef CONVERT_IF_TYPES_MATCH
case C64: case C64:
return ConvertToC64<primitive_src_type>(src_literal); return ConvertToC64<primitive_src_type>(src_literal);
@ -949,6 +986,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
CONVERT_IF_DEST_TYPE_MATCHES(F16) CONVERT_IF_DEST_TYPE_MATCHES(F16)
CONVERT_IF_DEST_TYPE_MATCHES(F32) CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64) CONVERT_IF_DEST_TYPE_MATCHES(F64)
CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES #undef CONVERT_IF_DEST_TYPE_MATCHES
// Other types are not yet supported. // Other types are not yet supported.
default: default:
@ -1019,6 +1057,8 @@ bool Literal::operator==(const Literal& other) const {
return EqualElements<double>(*this, other, 0, &multi_index); return EqualElements<double>(*this, other, 0, &multi_index);
case F16: case F16:
return EqualElements<half>(*this, other, 0, &multi_index); return EqualElements<half>(*this, other, 0, &multi_index);
case BF16:
return EqualElements<bfloat16>(*this, other, 0, &multi_index);
case C64: case C64:
return EqualElements<complex64>(*this, other, 0, &multi_index); return EqualElements<complex64>(*this, other, 0, &multi_index);
default: default:
@ -1128,13 +1168,18 @@ tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
template <> template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() { tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
// TODO - there is an endianess problem here. fix it, or wait for uint16
// support in protobuf
auto values = mutable_f16s(); auto values = mutable_f16s();
return tensorflow::gtl::MutableArraySlice<half>(values->data(), return tensorflow::gtl::MutableArraySlice<half>(values->data(),
values->size()); values->size());
} }
template <>
tensorflow::gtl::MutableArraySlice<bfloat16>
Literal::GetMutableArraySlice<bfloat16>() {
auto values = mutable_bf16s();
return {values->data(), values->size()};
}
template <> template <>
tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const { tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
CHECK_EQ(shape().element_type(), PRED); CHECK_EQ(shape().element_type(), PRED);
@ -1205,6 +1250,12 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
f16s().size() / sizeof(half)); f16s().size() / sizeof(half));
} }
template <>
tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
CHECK_EQ(shape().element_type(), BF16);
return {bf16s().data(), bf16s().size()};
}
template <> template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>() tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const { const {
@ -1253,6 +1304,9 @@ bool Literal::IsAll(int8 value) const {
return AllElementsEqualValue<double>(*this, value); return AllElementsEqualValue<double>(*this, value);
case F16: case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value)); return AllElementsEqualValue<half>(*this, static_cast<half>(value));
case BF16:
return AllElementsEqualValue<bfloat16>(*this,
static_cast<bfloat16>(value));
case PRED: case PRED:
if (value == 0) { if (value == 0) {
return AllElementsEqualValue<bool>(*this, false); return AllElementsEqualValue<bool>(*this, false);
@ -1274,6 +1328,9 @@ bool Literal::IsAllFloat(float value) const {
return AllElementsEqualValue<double>(*this, value); return AllElementsEqualValue<double>(*this, value);
case F16: case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value)); return AllElementsEqualValue<half>(*this, static_cast<half>(value));
case BF16:
return AllElementsEqualValue<bfloat16>(*this,
static_cast<bfloat16>(value));
default: default:
return false; return false;
} }
@ -1310,6 +1367,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
return Get<complex64>(indices) == complex64(0.0f, 0.0f); return Get<complex64>(indices) == complex64(0.0f, 0.0f);
case F16: case F16:
return Get<half>(indices) == static_cast<half>(0.0f); return Get<half>(indices) == static_cast<half>(0.0f);
case BF16:
return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
case PRED: case PRED:
return Get<bool>(indices) == false; return Get<bool>(indices) == false;
default: default:
@ -1377,6 +1436,12 @@ void Literal::Resize<half>(int64 num_elements, half value) {
mutable_f16s()->resize(num_elements, value); mutable_f16s()->resize(num_elements, value);
} }
template <>
void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
mutable_bf16s()->resize(num_elements, value);
}
template <> template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value) { void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
@ -1425,6 +1490,19 @@ LiteralProto Literal::ToProto() const {
*proto.mutable_f16s() = *proto.mutable_f16s() =
string(reinterpret_cast<const char*>(f16s_.data()), string(reinterpret_cast<const char*>(f16s_.data()),
f16s_.size() * sizeof(half)); f16s_.size() * sizeof(half));
if (!kLittleEndian) {
ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
proto.f16s().size());
}
break;
case BF16:
*proto.mutable_bf16s() =
string(reinterpret_cast<const char*>(bf16s_.data()),
bf16s_.size() * sizeof(bfloat16));
if (!kLittleEndian) {
ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
proto.bf16s().size());
}
break; break;
case F32: case F32:
CopyToRepeatedField(proto.mutable_f32s(), f32s()); CopyToRepeatedField(proto.mutable_f32s(), f32s());
@ -1493,6 +1571,21 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
CHECK_EQ(0, s.size() % sizeof(half)); CHECK_EQ(0, s.size() % sizeof(half));
f16s_ = std::vector<half>(s.size() / sizeof(half)); f16s_ = std::vector<half>(s.size() / sizeof(half));
memcpy(f16s_.data(), s.data(), s.size()); memcpy(f16s_.data(), s.data(), s.size());
if (!kLittleEndian) {
ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
}
break;
}
case BF16: {
const string& s(literal_proto.bf16s());
CHECK_EQ(0, s.size() % sizeof(bfloat16));
bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
memcpy(bf16s_.data(), s.data(), s.size());
if (!kLittleEndian) {
ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
}
break; break;
} }
case F32: case F32:

View File

@ -163,6 +163,11 @@ class Literal {
const std::vector<complex64>& c64s() const { return c64s_; } const std::vector<complex64>& c64s() const { return c64s_; }
std::vector<complex64>* mutable_c64s() { return &c64s_; } std::vector<complex64>* mutable_c64s() { return &c64s_; }
int bf16s_size() const { return bf16s().size(); }
bfloat16 bf16s(int i) const { return bf16s_[i]; }
const std::vector<bfloat16>& bf16s() const { return bf16s_; }
std::vector<bfloat16>* mutable_bf16s() { return &bf16s_; }
int tuple_literals_size() const { return tuple_literals().size(); } int tuple_literals_size() const { return tuple_literals().size(); }
const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
Literal* add_tuple_literals() { Literal* add_tuple_literals() {
@ -622,6 +627,7 @@ class Literal {
std::vector<uint16> u16s_; std::vector<uint16> u16s_;
std::vector<uint32> u32s_; std::vector<uint32> u32s_;
std::vector<uint64> u64s_; std::vector<uint64> u64s_;
std::vector<bfloat16> bf16s_;
std::vector<half> f16s_; std::vector<half> f16s_;
std::vector<float> f32s_; std::vector<float> f32s_;
std::vector<double> f64s_; std::vector<double> f64s_;
@ -674,6 +680,9 @@ tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const;
template <> template <>
tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const; tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
template <>
tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const;
template <> template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>() tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const; const;
@ -714,6 +723,9 @@ tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice();
template <> template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice(); tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
template <>
tensorflow::gtl::MutableArraySlice<bfloat16> Literal::GetMutableArraySlice();
template <> template <>
tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice(); tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
@ -747,6 +759,9 @@ void Literal::Resize<double>(int64 num_elements, double value);
template <> template <>
void Literal::Resize<half>(int64 num_elements, half value); void Literal::Resize<half>(int64 num_elements, half value);
template <>
void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value);
template <> template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value); void Literal::Resize<complex64>(int64 num_elements, complex64 value);
@ -990,6 +1005,14 @@ inline half Literal::Get<half>(
return GetArraySlice<half>()[linear_index]; return GetArraySlice<half>()[linear_index];
} }
template <>
inline bfloat16 Literal::Get<bfloat16>(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
CHECK(shape().element_type() == BF16);
int64 linear_index = LinearIndex(multi_index);
return GetArraySlice<bfloat16>()[linear_index];
}
template <typename NativeT> template <typename NativeT>
void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
NativeT value) { NativeT value) {

View File

@ -110,6 +110,18 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f}); auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
ASSERT_EQ("0.5", bf16_lit->ToString());
// 3.14 will be rounded to 3.125 in bfloat16 format (Round to nearest even).
auto bf16_lit_truncated =
Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
ASSERT_EQ("3.140625", bf16_lit_truncated->ToString());
auto bf16_lit_truncated2 =
Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
ASSERT_EQ("9", bf16_lit_truncated2->ToString());
} }
TEST_F(LiteralUtilTest, LiteralVectorToString) { TEST_F(LiteralUtilTest, LiteralVectorToString) {
@ -397,6 +409,18 @@ TEST_F(LiteralUtilTest, IsAll) {
EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
bfloat16 b8(8.0f);
bfloat16 b9(9.0f);
EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
// 9.001 will be truncated to 9.0
bfloat16 b91(9.001f);
bfloat16 b90(9.00f);
EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
complex64 c8_9 = {8, 9}; complex64 c8_9 = {8, 9};
EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
@ -691,6 +715,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
EXPECT_EQ(output, *expected); EXPECT_EQ(output, *expected);
} }
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
Literal output;
bfloat16 h(0.25f);
output.PopulateWithValue<bfloat16>(h, {});
auto expected = Literal::CreateR0<bfloat16>(h);
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
Literal output;
bfloat16 h(0.5f);
output.PopulateWithValue<bfloat16>(h, {3});
auto expected = Literal::CreateR1<bfloat16>({h, h, h});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
Literal output;
bfloat16 h(2.0f);
output.PopulateWithValue<bfloat16>(h, {2, 2});
auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output; Literal output;
output.PopulateWithValue<float>(2.5f, {}); output.PopulateWithValue<float>(2.5f, {});
@ -975,6 +1023,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{half(26.0), half(0.0), half(28.0), half(0.0)}, {{half(26.0), half(0.0), half(28.0), half(0.0)},
{half(0.0), half(31.0), half(0.0), half(33.0)}}, {half(0.0), half(31.0), half(0.0), half(33.0)}},
}}, layout_r4_dim0major_); }}, layout_r4_dim0major_);
auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{
{{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
{{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
{bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
{{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
}}, layout_r4_dim0major_);
auto f32 = Literal::CreateR4WithLayout<float>({{ auto f32 = Literal::CreateR4WithLayout<float>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
@ -1008,6 +1064,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
conv = s8->Convert(PRED).ConsumeValueOrDie(); conv = s8->Convert(PRED).ConsumeValueOrDie();
EXPECT_EQ(*conv, *pred); EXPECT_EQ(*conv, *pred);
conv = bf16->Convert(S32).ConsumeValueOrDie();
EXPECT_EQ(*conv, *s32);
conv = bf16->Convert(F32).ConsumeValueOrDie();
EXPECT_EQ(*conv, *f32);
conv = pred->Convert(S32).ConsumeValueOrDie(); conv = pred->Convert(S32).ConsumeValueOrDie();
EXPECT_EQ(*conv, *int32_pred); EXPECT_EQ(*conv, *int32_pred);

View File

@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType<double>() {
return F64; return F64;
} }
template <>
PrimitiveType NativeToPrimitiveType<bfloat16>() {
return BF16;
}
template <> template <>
PrimitiveType NativeToPrimitiveType<half>() { PrimitiveType NativeToPrimitiveType<half>() {
return F16; return F16;
@ -89,7 +94,7 @@ PrimitiveType NativeToPrimitiveType<complex64>() {
} }
bool IsFloatingPointType(PrimitiveType type) { bool IsFloatingPointType(PrimitiveType type) {
return type == F16 || type == F32 || type == F64; return type == F16 || type == F32 || type == F64 || type == BF16;
} }
bool IsComplexType(PrimitiveType type) { return type == C64; } bool IsComplexType(PrimitiveType type) { return type == C64; }
@ -118,6 +123,7 @@ int BitWidth(PrimitiveType type) {
case S16: case S16:
case U16: case U16:
case F16: case F16:
case BF16:
return 16; return 16;
case U32: case U32:

View File

@ -77,6 +77,8 @@ template <>
PrimitiveType NativeToPrimitiveType<double>(); PrimitiveType NativeToPrimitiveType<double>();
template <> template <>
PrimitiveType NativeToPrimitiveType<half>(); PrimitiveType NativeToPrimitiveType<half>();
template <>
PrimitiveType NativeToPrimitiveType<bfloat16>();
// Complex // Complex
template <> template <>
@ -167,6 +169,11 @@ struct PrimitiveTypeToNative<F16> {
using type = half; using type = half;
}; };
template <>
struct PrimitiveTypeToNative<BF16> {
using type = bfloat16;
};
// Complex // Complex
template <> template <>
struct PrimitiveTypeToNative<C64> { struct PrimitiveTypeToNative<C64> {

View File

@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/backend.h"
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <utility> #include <utility>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/platform_util.h"

View File

@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <tuple> #include <tuple>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"

View File

@ -1450,6 +1450,10 @@ HloEvaluator::HloEvaluator() {
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this); typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this); typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this); typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
typed_visitors_[BF16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("HloEvaluator: unhandled primitive type: BF16.");
});
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) { typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
}); });

View File

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/service/hlo_runner.h"
@ -19,8 +20,6 @@ limitations under the License.
#include <string> #include <string>
#include <utility> #include <utility>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/ptr_util.h"

View File

@ -263,6 +263,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
case S32: case S32:
case S64: case S64:
case F16: case F16:
case BF16:
case F32: case F32:
case F64: case F64:
return true; return true;

View File

@ -116,16 +116,18 @@ template <typename FloatT, typename UnsignedT>
::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { ::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs); auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs); auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
auto lhs_double = static_cast<double>(lhs);
auto rhs_double = static_cast<double>(rhs);
if (ulhs != urhs) { if (ulhs != urhs) {
return ::testing::AssertionFailure() << tensorflow::strings::Printf( return ::testing::AssertionFailure() << tensorflow::strings::Printf(
"floating values are not bitwise-equal; and equality testing " "floating values are not bitwise-equal; and equality testing "
"was requested: %s=%g=%a vs %s=%g=%a", "was requested: %s=%g=%a vs %s=%g=%a",
tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs))
.c_str(), .c_str(),
lhs, lhs, lhs_double, lhs_double,
tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs)) tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs))
.c_str(), .c_str(),
rhs, rhs); rhs_double, rhs_double);
} }
return ::testing::AssertionSuccess(); return ::testing::AssertionSuccess();
} }
@ -149,6 +151,10 @@ template <typename NativeT>
// Specializations for floating types that do bitwise comparisons when equality // Specializations for floating types that do bitwise comparisons when equality
// comparison is requested. // comparison is requested.
template <> template <>
::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
}
template <>
::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) { ::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs); return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
} }
@ -238,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
case U64: case U64:
match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0); match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
break; break;
case BF16:
match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
break;
case F32: case F32:
match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0); match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
break; break;

View File

@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include <vector> #include <vector>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/map_util.h"

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <complex> #include <complex>
#include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include <Eigen/Core> #include <Eigen/Core>
@ -32,6 +33,8 @@ using ::tensorflow::int16;
using ::tensorflow::int32; using ::tensorflow::int32;
using ::tensorflow::int64; using ::tensorflow::int64;
using ::tensorflow::bfloat16;
using ::tensorflow::uint8; using ::tensorflow::uint8;
using ::tensorflow::uint16; using ::tensorflow::uint16;
using ::tensorflow::uint32; using ::tensorflow::uint32;

View File

@ -46,6 +46,12 @@ enum PrimitiveType {
// converted to f16 from f32 at arbirary points in the computation. // converted to f16 from f32 at arbirary points in the computation.
F16 = 10; F16 = 10;
F32 = 11; F32 = 11;
// Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
// floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
// and 7 bits for the mantissa.
BF16 = 16;
F64 = 12; F64 = 12;
// Complex values of fixed width. // Complex values of fixed width.
@ -63,6 +69,8 @@ enum PrimitiveType {
// An opaque type used for passing context specific data to a custom // An opaque type used for passing context specific data to a custom
// operation. // operation.
OPAQUE = 14; OPAQUE = 14;
// Next = 17
} }
// Describes the value held inside padding elements. // Describes the value held inside padding elements.
@ -310,7 +318,10 @@ message LiteralProto {
repeated double f64s = 9; repeated double f64s = 9;
repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated float c64s = 12; // Stored as interleaved real, imag floats.
repeated LiteralProto tuple_literals = 10; repeated LiteralProto tuple_literals = 10;
bytes f16s = 11; // Note: the F16s are encoded in little endian byte order // The F16s and BF16s are encoded in little endian byte order
bytes f16s = 11;
bytes bf16s = 13;
// Next = 14
} }
message WindowDimension { message WindowDimension {

View File

@ -18,17 +18,9 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) { void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
const uint16_t* p = reinterpret_cast<const uint16_t*>(src); for (int64 i = 0; i < size; ++i) {
uint16_t* q = reinterpret_cast<uint16_t*>(dst); dst[i] = bfloat16(src[i]);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
for (; size != 0; p += 2, q++, size--) {
*q = p[0];
} }
#else
for (; size != 0; p += 2, q++, size--) {
*q = p[1];
}
#endif
} }
void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) { void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/test_benchmark.h"
@ -27,6 +28,97 @@ TEST(Bfloat16Test, Simple) {
EXPECT_EQ(0x4140, a.value); EXPECT_EQ(0x4140, a.value);
} }
float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
uint32_t low_mantissa) {
return bit_cast<float>((sign << 31) + (exponent << 23) +
(high_mantissa << 16) + low_mantissa);
}
struct Bfloat16TestParam {
float input;
float expected;
};
class Bfloat16Test : public ::testing::Test,
public ::testing::WithParamInterface<Bfloat16TestParam> {};
TEST_P(Bfloat16Test, RoundOrTruncate) {
bfloat16 a(GetParam().input);
if (std::isnan(GetParam().input)) {
EXPECT_TRUE(std::isnan(float(a)));
return;
}
EXPECT_EQ(GetParam().expected, float(a));
}
INSTANTIATE_TEST_CASE_P(
Bfloat16Test_Instantiation, Bfloat16Test,
::testing::Values(
// More than half.
Bfloat16TestParam{
BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011),
BinaryToFloat(0, 0b10000000, 0b1001001, 0b0000000000000000)},
Bfloat16TestParam{
BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011),
BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
// Exact half.
Bfloat16TestParam{
BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
// NaN stays at NaN.
Bfloat16TestParam{
BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001),
BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
// NaN stays at NaN -- no exponents overflow.
Bfloat16TestParam{
BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111),
BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
// More than half, round to an odd number.
Bfloat16TestParam{
BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000),
BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
// Less than half, truncate.
Bfloat16TestParam{
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
// Less than half, truncate.
Bfloat16TestParam{
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000),
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
// Exact at half, but result is already even.
Bfloat16TestParam{
BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
// Denormal values.
Bfloat16TestParam{
BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
Bfloat16TestParam{
BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
BinaryToFloat(0, 0b00000001, 0b0000000, 0b0000000000000000)}));
TEST(Bfloat16Test, RoundWithFractionOverflow) {
// Still works with fraction overflow -- round to 4./
//
// Input 3.9960938:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// 0 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1100000000000000
//
// Should round to 4.0:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0
bfloat16 a(3.9960938f);
EXPECT_EQ(4.0, float(a));
}
TEST(Bfloat16Test, Conversion) { TEST(Bfloat16Test, Conversion) {
float a[100]; float a[100];
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {

View File

@ -44,29 +44,262 @@ typedef Eigen::QUInt16 quint16;
// see framework/bfloat16.h for description. // see framework/bfloat16.h for description.
struct bfloat16 { struct bfloat16 {
EIGEN_DEVICE_FUNC bfloat16() {} EIGEN_DEVICE_FUNC bfloat16() {}
EIGEN_DEVICE_FUNC explicit bfloat16(const float v) {
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v); explicit EIGEN_DEVICE_FUNC bfloat16(float v) {
uint32_t input;
memcpy(&input, &v, sizeof(uint32_t));
if ((~input & 0x7f800000) == 0 && (input & 0x007fffff) != 0) {
// If the value is a NaN, squash it to a qNaN with msb of fraction set,
// this makes sure after truncation we don't end up with an inf.
//
// qNaN magic: All exponent bits set + most significant bit of fraction
// set.
value = 0x7fc0;
} else {
// Fast rounding algorithm that rounds a half value to nearest even. This
// reduces expected error when we convert a large number of floats. Here
// is how it works:
//
// Definitions:
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
// with the following tags:
//
// Sign | Exp (8 bits) | Frac (23 bits)
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
//
// S: Sign bit.
// E: Exponent bits.
// F: First 6 bits of fraction.
// L: Least significant bit of resulting bfloat16 if we truncate away the
// rest of the float32. This is also the 7th bit of fraction
// R: Rounding bit, 8th bit of fraction.
// T: Sticky bits, rest of fraction, 15 bits.
//
// To round half to nearest even, there are 3 cases where we want to round
// down (simply truncate the result of the bits away, which consists of
// rounding bit and sticky bits) and two cases where we want to round up
// (truncate then add one to the result).
//
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
// 1s) as the rounding bias, adds the rounding bias to the input, then
// truncates the last 16 bits away.
//
// To understand how it works, we can analyze this algorithm case by case:
//
// 1. L = 0, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input may create any carry, depending on
// whether there is any value set to 1 in T bits.
// - R may be set to 1 if there is a carry.
// - L remains 0.
// - Note that this case also handles Inf and -Inf, where all fraction
// bits, including L, R and Ts are all 0. The output remains Inf after
// this algorithm.
//
// 2. L = 1, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits but
// adds 1 to rounding bit.
// - L remains 1.
//
// 3. L = 0, R = 1, all of T are 0:
// Expect: round down, this is exactly at half, the result is already
// even (L=0).
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input sets all sticky bits to 1, but
// doesn't create a carry.
// - R remains 1.
// - L remains 0.
//
// 4. L = 1, R = 1:
// Expect: round up, this is exactly at half, the result needs to be
// round to the next even number.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits, but
// creates a carry from rounding bit.
// - The carry sets L to 0, creates another carry bit and propagate
// forward to F bits.
// - If all the F bits are 1, a carry then propagates to the exponent
// bits, which then creates the minimum value with the next exponent
// value. Note that we won't have the case where exponents are all 1,
// since that's either a NaN (handled in the other if condition) or inf
// (handled in case 1).
//
// 5. L = 0, R = 1, any of T is 1:
// Expect: round up, this is greater than half.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input creates a carry from sticky bits,
// sets rounding bit to 0, then create another carry.
// - The second carry sets L to 1.
//
// Examples:
//
// Exact half value that is already even:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
//
// This falls into case 3. We truncate the rest of 16 bits and no
// carry is created into F and L:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
// Exact half value, round to next even number:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
//
// This falls into case 4. We create a carry from R and T,
// which then propagates into L and F:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
//
// Max denormal value round to min normal value:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
//
// Max normal value round to Inf:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
//
//
// Least significant bit of resulting bfloat.
uint32_t lsb = (input >> 16) & 1;
uint32_t rounding_bias = 0x7fff + lsb;
input += rounding_bias;
value = static_cast<uint16_t>(input >> 16);
}
}
template <class T>
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
: bfloat16(static_cast<float>(val)) {}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
float result;
uint16_t* q = reinterpret_cast<uint16_t*>(&result);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
value = p[0]; q[0] = value;
q[1] = 0;
#else #else
value = p[1]; q[0] = 0;
q[1] = value;
#endif #endif
return result;
}
EIGEN_DEVICE_FUNC explicit operator bool() const {
return static_cast<bool>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator Eigen::half() const {
return static_cast<Eigen::half>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator short() const {
return static_cast<short>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator int() const {
return static_cast<int>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator char() const {
return static_cast<char>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator signed char() const {
return static_cast<signed char>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator unsigned char() const {
return static_cast<unsigned char>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator unsigned int() const {
return static_cast<unsigned int>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator unsigned long() const {
return static_cast<unsigned long>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator unsigned long long() const {
return static_cast<unsigned long long>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator long long() const {
return static_cast<long long>(float(*this));
}
EIGEN_DEVICE_FUNC explicit operator double() const {
return static_cast<double>(float(*this));
} }
uint16_t value; uint16_t value;
}; };
inline bool operator==(const bfloat16 a, const bfloat16 b) {
return a.value == b.value;
}
inline bool operator!=(const bfloat16 a, const bfloat16 b) {
return a.value != b.value;
}
} // end namespace tensorflow } // end namespace tensorflow
namespace Eigen { namespace Eigen {
template <> template <>
struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {}; struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {};
EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a, using ::tensorflow::operator==;
const tensorflow::bfloat16 b) { using ::tensorflow::operator!=;
return a.value == b.value;
}
} // namespace Eigen } // namespace Eigen
#ifdef COMPILER_MSVC #ifdef COMPILER_MSVC