From bc4d2043b631b78365560fb893ac81e93935535e Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Mon, 13 Nov 2017 11:48:10 -0800
Subject: [PATCH] Add bfloat16 support to XLA.

This is a prerequisite for providing bfloat16 support in the GPU backend.

RELNOTES: bfloat16 support is now added to the XLA infrastructure.
PiperOrigin-RevId: 175564791
---
 tensorflow/compiler/tf2xla/type_util.cc       |  3 +
 tensorflow/compiler/xla/BUILD                 |  1 +
 tensorflow/compiler/xla/literal_util.cc       | 99 ++++++++++++++++++-
 tensorflow/compiler/xla/literal_util.h        | 23 +++++
 tensorflow/compiler/xla/literal_util_test.cc  | 62 ++++++++++++
 tensorflow/compiler/xla/primitive_util.cc     |  8 +-
 tensorflow/compiler/xla/primitive_util.h      |  7 ++
 tensorflow/compiler/xla/service/backend.cc    |  4 +-
 .../xla/service/cpu/cpu_runtime_test.cc       |  4 +-
 .../compiler/xla/service/hlo_evaluator.cc     |  4 +
 tensorflow/compiler/xla/service/hlo_runner.cc |  3 +-
 tensorflow/compiler/xla/shape_util.cc         |  1 +
 .../compiler/xla/tests/literal_test_util.cc   | 13 ++-
 .../xla/tests/local_client_test_base.cc       |  3 +-
 tensorflow/compiler/xla/types.h               |  3 +
 tensorflow/compiler/xla/xla_data.proto        | 13 ++-
 tensorflow/core/framework/bfloat16_test.cc    | 61 ++++++++++++
 tensorflow/core/framework/numeric_types.h     | 83 +++++++++++++++-
 18 files changed, 374 insertions(+), 21 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index 1efbe0ffb17..c969212a1bf 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -49,6 +49,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
     case tensorflow::DT_UINT64:
       *type = xla::U64;
       return Status::OK();
+    case tensorflow::DT_BFLOAT16:
+      *type = xla::BF16;
+      return Status::OK();
     case tensorflow::DT_HALF:
       *type = xla::F16;
       return Status::OK();
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index fa4d348ebdc..515b572b0eb 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -77,6 +77,7 @@ cc_library(
     hdrs = ["types.h"],
     visibility = [":friends"],
     deps = [
+        "//tensorflow/core:framework_lite",
         "//tensorflow/core:lib",
         "//third_party/eigen3",
     ],
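Background for the literal and numeric-type hunks that follow: bfloat16 is simply the upper half of an IEEE-754 float32 (1 sign bit, 8 exponent bits, 7 mantissa bits), so converting float32 to bfloat16 amounts to dropping the low 16 mantissa bits, and widening back zero-fills them. A minimal standalone sketch of that round trip (illustrative only, not code from this patch):

    #include <cstdint>
    #include <cstring>
    #include <iostream>

    // Truncate a float32 to its upper 16 bits (sign, exponent, 7 mantissa bits).
    uint16_t FloatToBFloat16Bits(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      return static_cast<uint16_t>(bits >> 16);
    }

    // Widen bfloat16 bits back to float32 by zero-filling the low 16 bits.
    float BFloat16BitsToFloat(uint16_t b) {
      uint32_t bits = static_cast<uint32_t>(b) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

    int main() {
      // 3.14f (bit pattern 0x4048F5C3) loses its low mantissa bits
      // and comes back as 3.125 (0x40480000).
      std::cout << BFloat16BitsToFloat(FloatToBFloat16Bits(3.14f)) << "\n";
    }

This truncate-not-round behavior is why the tests later in this patch expect 3.14 to read back as 3.125.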
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 0cb2223ae5a..93d3cd425f0 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -33,6 +33,20 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+
+namespace {
+
+using tensorflow::int64;
+
+constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
+
+// Converts between little and big endian, assuming elements in the array are
+// 16 bits long.
+void ConvertEndianShort(char* bytes, int64 size) {
+  // The byte count must be even so the buffer holds whole 16-bit elements.
+  CHECK_EQ(size % 2, 0);
+  for (int64 i = 0; i < size; i += 2) {
+    std::swap(bytes[i], bytes[i + 1]);
+  }
+}
+
+}  // namespace
 
 namespace xla {
 
@@ -169,6 +183,8 @@ Status Literal::Copy(const Literal& src_literal,
       return CopyRange<uint64>(src_literal, src_base, dest_base, copy_size);
     case F16:
       return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
+    case BF16:
+      return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
     case F32:
       return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
     case F64:
@@ -200,6 +216,8 @@ Status Literal::Copy(const Literal& src_literal,
       return *Literal::CreateR0<uint64>(0);
     case F16:
       return *Literal::CreateR0<half>(static_cast<half>(0.0f));
+    case BF16:
+      return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
     case F32:
       return *Literal::CreateR0<float>(0);
     case F64:
@@ -285,6 +303,9 @@ Status Literal::Copy(const Literal& src_literal,
     case F16:
       return *Literal::CreateR0<half>(
           static_cast<half>(-std::numeric_limits<float>::infinity()));
+    case BF16:
+      return *Literal::CreateR0<bfloat16>(
+          static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no minimum value";
     case OPAQUE:
@@ -321,6 +342,9 @@ Status Literal::Copy(const Literal& src_literal,
     case F16:
       return *Literal::CreateR0<half>(
           static_cast<half>(std::numeric_limits<float>::infinity()));
+    case BF16:
+      return *Literal::CreateR0<bfloat16>(
+          static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no maximum value";
     case OPAQUE:
@@ -428,6 +452,7 @@ std::unique_ptr<Literal> Literal::Transpose(
   // The shape with affine layout resulting from that operation will be
   // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
   // most minor.
+  //
   // Essentially, given MinMaj(Di) the position of the Di dimension within the
   // minor to major vector, and given T(Di) the index that the original Di
   // dimension has within the transposed array, a layout is affine if
@@ -536,6 +561,9 @@ string Literal::GetAsString(
     }
     case F16:
       return tensorflow::strings::StrCat(Get<half>(multi_index));
+    case BF16:
+      return tensorflow::strings::StrCat(
+          static_cast<float>(Get<bfloat16>(multi_index)));
     default:
       return tensorflow::strings::StrCat(
           "[", PrimitiveType_Name(shape().element_type()), "]");
@@ -743,6 +771,8 @@ void* Literal::MutableInternalData() {
       return reinterpret_cast<void*>(c64s_.data());
     case F16:
       return reinterpret_cast<void*>(f16s_.data());
+    case BF16:
+      return reinterpret_cast<void*>(bf16s_.data());
     default:
       LOG(FATAL) << "primitive type not supported in literals: "
                  << PrimitiveType_Name(shape().element_type());
@@ -785,6 +815,9 @@ void Literal::Reserve(int64 num_elements) {
     case F16:
       Resize<half>(num_elements, static_cast<half>(0.0f));
       break;
+    case BF16:
+      Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
+      break;
     default:
       LOG(FATAL) << "primitive type not supported in literals: "
                  << PrimitiveType_Name(shape().element_type());
@@ -824,6 +857,9 @@ tensorflow::Status Literal::ValidateLiteral() const {
     case F16:
       actual = f16s().size() / sizeof(half);
       break;
+    case BF16:
+      actual = bf16s().size();
+      break;
     default:
       return tensorflow::errors::Unimplemented(
           "unhandled element type for literal validation: " +
@@ -920,6 +956,7 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
       CONVERT_IF_TYPES_MATCH(F16)
       CONVERT_IF_TYPES_MATCH(F32)
       CONVERT_IF_TYPES_MATCH(F64)
+      CONVERT_IF_TYPES_MATCH(BF16)
 #undef CONVERT_IF_TYPES_MATCH
     case C64:
       return ConvertToC64(src_literal);
@@ -949,8 +986,9 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
       CONVERT_IF_DEST_TYPE_MATCHES(F16)
       CONVERT_IF_DEST_TYPE_MATCHES(F32)
      CONVERT_IF_DEST_TYPE_MATCHES(F64)
+      CONVERT_IF_DEST_TYPE_MATCHES(BF16)
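+      // (Assumption: each CONVERT_IF_DEST_TYPE_MATCHES(T) expands to a
+      // `case T:` in this switch that forwards to the templated converter,
+      // mirroring the CONVERT_IF_TYPES_MATCH cases earlier in this file;
+      // BF16 simply joins the list, and the `default:` below rejects
+      // unsupported destination types.)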
#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. + // Other types are not yet supported. default: return InvalidArgument("Unimplemented: Convert from type %s to type %s", PrimitiveType_Name(shape().element_type()).c_str(), @@ -1019,6 +1057,8 @@ bool Literal::operator==(const Literal& other) const { return EqualElements(*this, other, 0, &multi_index); case F16: return EqualElements(*this, other, 0, &multi_index); + case BF16: + return EqualElements(*this, other, 0, &multi_index); case C64: return EqualElements(*this, other, 0, &multi_index); default: @@ -1128,13 +1168,18 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - // TODO - there is an endianess problem here. fix it, or wait for uint16 - // support in protobuf auto values = mutable_f16s(); return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } +template <> +tensorflow::gtl::MutableArraySlice +Literal::GetMutableArraySlice() { + auto values = mutable_bf16s(); + return {values->data(), values->size()}; +} + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { CHECK_EQ(shape().element_type(), PRED); @@ -1205,6 +1250,12 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { f16s().size() / sizeof(half)); } +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), BF16); + return {bf16s().data(), bf16s().size()}; +} + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { @@ -1253,6 +1304,9 @@ bool Literal::IsAll(int8 value) const { return AllElementsEqualValue(*this, value); case F16: return AllElementsEqualValue(*this, static_cast(value)); + case BF16: + return AllElementsEqualValue(*this, + static_cast(value)); case PRED: if (value == 0) { return AllElementsEqualValue(*this, false); @@ -1274,6 +1328,9 @@ bool Literal::IsAllFloat(float value) const { return AllElementsEqualValue(*this, value); case F16: return AllElementsEqualValue(*this, static_cast(value)); + case BF16: + return AllElementsEqualValue(*this, + static_cast(value)); default: return false; } @@ -1310,6 +1367,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { return Get(indices) == complex64(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); + case BF16: + return Get(indices) == static_cast(0.0f); case PRED: return Get(indices) == false; default: @@ -1377,6 +1436,12 @@ void Literal::Resize(int64 num_elements, half value) { mutable_f16s()->resize(num_elements, value); } +template <> +void Literal::Resize(int64 num_elements, bfloat16 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_bf16s()->resize(num_elements, value); +} + template <> void Literal::Resize(int64 num_elements, complex64 value) { CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); @@ -1425,6 +1490,19 @@ LiteralProto Literal::ToProto() const { *proto.mutable_f16s() = string(reinterpret_cast(f16s_.data()), f16s_.size() * sizeof(half)); + if (!kLittleEndian) { + ConvertEndianShort(const_cast(proto.mutable_f16s()->data()), + proto.f16s().size()); + } + break; + case BF16: + *proto.mutable_bf16s() = + string(reinterpret_cast(bf16s_.data()), + bf16s_.size() * sizeof(bfloat16)); + if (!kLittleEndian) { + ConvertEndianShort(const_cast(proto.mutable_bf16s()->data()), + proto.bf16s().size()); + } break; case F32: CopyToRepeatedField(proto.mutable_f32s(), f32s()); @@ -1493,6 +1571,21 @@ void 
Literal::CopyFromProto(const LiteralProto& literal_proto) { CHECK_EQ(0, s.size() % sizeof(half)); f16s_ = std::vector(s.size() / sizeof(half)); memcpy(f16s_.data(), s.data(), s.size()); + + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(f16s_.data()), s.size()); + } + break; + } + case BF16: { + const string& s(literal_proto.bf16s()); + CHECK_EQ(0, s.size() % sizeof(bfloat16)); + bf16s_ = std::vector(s.size() / sizeof(bfloat16)); + memcpy(bf16s_.data(), s.data(), s.size()); + + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(bf16s_.data()), s.size()); + } break; } case F32: diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 667f926c464..f37e529caf5 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -163,6 +163,11 @@ class Literal { const std::vector& c64s() const { return c64s_; } std::vector* mutable_c64s() { return &c64s_; } + int bf16s_size() const { return bf16s().size(); } + bfloat16 bf16s(int i) const { return bf16s_[i]; } + const std::vector& bf16s() const { return bf16s_; } + std::vector* mutable_bf16s() { return &bf16s_; } + int tuple_literals_size() const { return tuple_literals().size(); } const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } Literal* add_tuple_literals() { @@ -622,6 +627,7 @@ class Literal { std::vector u16s_; std::vector u32s_; std::vector u64s_; + std::vector bf16s_; std::vector f16s_; std::vector f32s_; std::vector f64s_; @@ -674,6 +680,9 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; @@ -714,6 +723,9 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); @@ -747,6 +759,9 @@ void Literal::Resize(int64 num_elements, double value); template <> void Literal::Resize(int64 num_elements, half value); +template <> +void Literal::Resize(int64 num_elements, bfloat16 value); + template <> void Literal::Resize(int64 num_elements, complex64 value); @@ -990,6 +1005,14 @@ inline half Literal::Get( return GetArraySlice()[linear_index]; } +template <> +inline bfloat16 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == BF16); + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice()[linear_index]; +} + template void Literal::Set(tensorflow::gtl::ArraySlice multi_index, NativeT value) { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 6d596da4ada..816bb3c549e 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -110,6 +110,18 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); + + auto bf16_lit = Literal::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", bf16_lit->ToString()); + + // 3.14 will be truncated to 3.125 in bfloat16 format. 
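+  // (In bits: 3.14f is 0x4048F5C3; bfloat16 keeps the high half, 0x4048,
+  // which widens back to 0x40480000 = 3.125f.)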
+ auto bf16_lit_truncated = + Literal::CreateR0(static_cast(3.14f)); + ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); + + auto bf16_lit_truncated2 = + Literal::CreateR0(static_cast(9.001f)); + ASSERT_EQ("9", bf16_lit_truncated2->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -397,6 +409,18 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + bfloat16 b8(8.0f); + bfloat16 b9(9.0f); + + EXPECT_TRUE(Literal::CreateR2({{b8}, {b8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{b8}, {b9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{b9}, {b8}})->IsAll(8)); + + // 9.001 will be truncated to 9.0 + bfloat16 b91(9.001f); + bfloat16 b90(9.00f); + EXPECT_TRUE(Literal::CreateR2({{b91}, {b90}})->IsAll(9.0)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); @@ -691,6 +715,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { + Literal output; + bfloat16 h(0.25f); + output.PopulateWithValue(h, {}); + auto expected = Literal::CreateR0(h); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { + Literal output; + bfloat16 h(0.5f); + output.PopulateWithValue(h, {3}); + auto expected = Literal::CreateR1({h, h, h}); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { + Literal output; + bfloat16 h(2.0f); + output.PopulateWithValue(h, {2, 2}); + auto expected = Literal::CreateR2({{h, h}, {h, h}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue(2.5f, {}); @@ -975,6 +1023,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{half(26.0), half(0.0), half(28.0), half(0.0)}, {half(0.0), half(31.0), half(0.0), half(33.0)}}, }}, layout_r4_dim0major_); + auto bf16 = Literal::CreateR4WithLayout({{ + {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}}, + {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)}, + {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}}, + {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}}, + }}, layout_r4_dim0major_); auto f32 = Literal::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, @@ -1008,6 +1064,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = s8->Convert(PRED).ConsumeValueOrDie(); EXPECT_EQ(*conv, *pred); + conv = bf16->Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *s32); + + conv = bf16->Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *f32); + conv = pred->Convert(S32).ConsumeValueOrDie(); EXPECT_EQ(*conv, *int32_pred); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 2113b5e06f3..2bce56b7bd2 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType() { return F64; } +template <> +PrimitiveType NativeToPrimitiveType() { + return BF16; +} + template <> PrimitiveType NativeToPrimitiveType() { return F16; @@ -89,7 +94,7 @@ PrimitiveType NativeToPrimitiveType() { } bool IsFloatingPointType(PrimitiveType type) { - return type == F16 || type == F32 || type == 
F64; + return type == F16 || type == F32 || type == F64 || type == BF16; } bool IsComplexType(PrimitiveType type) { return type == C64; } @@ -118,6 +123,7 @@ int BitWidth(PrimitiveType type) { case S16: case U16: case F16: + case BF16: return 16; case U32: diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index a49c8b86fcf..19c6a138885 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -77,6 +77,8 @@ template <> PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +template <> +PrimitiveType NativeToPrimitiveType(); // Complex template <> @@ -167,6 +169,11 @@ struct PrimitiveTypeToNative { using type = half; }; +template <> +struct PrimitiveTypeToNative { + using type = bfloat16; +}; + // Complex template <> struct PrimitiveTypeToNative { diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 9abe30e3f37..05f2d062784 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS + #include "tensorflow/compiler/xla/service/backend.h" #include #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index f8e260dd901..f385829cdf5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 88b77ccdd03..a722d1b3d99 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1450,6 +1450,10 @@ HloEvaluator::HloEvaluator() { typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); typed_visitors_[C64] = MakeUnique>(this); + + typed_visitors_[BF16] = MakeUnique([](HloInstruction*) { + return Unimplemented("HloEvaluator: unhandled primitive type: BF16."); + }); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); }); diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index f463e57d995..158fb9a546c 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/service/hlo_runner.h" @@ -19,8 +20,6 @@ limitations under the License. #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index b5eb81dfc6a..4d0bafa9087 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -263,6 +263,7 @@ StatusOr MakeShapeWithLayoutInternal( case S32: case S64: case F16: + case BF16: case F32: case F64: return true; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 95a52ecd2f5..75c9a0d3fb5 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -116,16 +116,18 @@ template ::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); + auto lhs_double = static_cast(lhs); + auto rhs_double = static_cast(rhs); if (ulhs != urhs) { return ::testing::AssertionFailure() << tensorflow::strings::Printf( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a", tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) .c_str(), - lhs, lhs, + lhs_double, lhs_double, tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs)) .c_str(), - rhs, rhs); + rhs_double, rhs_double); } return ::testing::AssertionSuccess(); } @@ -149,6 +151,10 @@ template // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. 
template <> +::testing::AssertionResult CompareEqual(bfloat16 lhs, bfloat16 rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> ::testing::AssertionResult CompareEqual(float lhs, float rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } @@ -238,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case U64: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; + case BF16: + match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); + break; case F32: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index c11e1df0a78..d98875dbc20 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/map_util.h" diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 3b19ca321ca..9fa4297523b 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/platform/types.h" #include @@ -32,6 +33,8 @@ using ::tensorflow::int16; using ::tensorflow::int32; using ::tensorflow::int64; +using ::tensorflow::bfloat16; + using ::tensorflow::uint8; using ::tensorflow::uint16; using ::tensorflow::uint32; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 71466047080..eac8f2ff07e 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -46,6 +46,12 @@ enum PrimitiveType { // converted to f16 from f32 at arbirary points in the computation. F16 = 10; F32 = 11; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + // and 7 bits for the mantissa. + BF16 = 16; + F64 = 12; // Complex values of fixed width. @@ -63,6 +69,8 @@ enum PrimitiveType { // An opaque type used for passing context specific data to a custom // operation. OPAQUE = 14; + + // Next = 17 } // Describes the value held inside padding elements. @@ -310,7 +318,10 @@ message LiteralProto { repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. 
  repeated LiteralProto tuple_literals = 10;
-  bytes f16s = 11;  // Note: the F16s are encoded in little endian byte order
+  // The F16s and BF16s are encoded in little endian byte order
+  bytes f16s = 11;
+  bytes bf16s = 13;
+  // Next = 14
 }
 
 message WindowDimension {
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index af4e6a44116..6e453387516 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/bfloat16.h"
 
+#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 
@@ -27,6 +28,66 @@ TEST(Bfloat16Test, Simple) {
   EXPECT_EQ(0x4140, a.value);
 }
 
+float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
+                    uint32_t low_mantissa) {
+  return bit_cast<float>((sign << 31) + (exponent << 23) +
+                         (high_mantissa << 16) + low_mantissa);
+}
+
+struct Bfloat16TestParam {
+  float input;
+  float expected;
+};
+
+class Bfloat16Test
+    : public ::testing::Test,
+      public ::testing::WithParamInterface<Bfloat16TestParam> {};
+
+TEST_P(Bfloat16Test, TruncateTest) {
+  bfloat16 a(GetParam().input);
+  if (std::isnan(GetParam().input)) {
+    EXPECT_TRUE(std::isnan(float(a)) || std::isinf(float(a)));
+    return;
+  }
+  EXPECT_EQ(GetParam().expected, float(a));
+}
+
+INSTANTIATE_TEST_CASE_P(
+    Bfloat16Test_Instantiation, Bfloat16Test,
+    ::testing::Values(
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011),
+            BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001),
+            BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111),
+            BinaryToFloat(0, 0b11111111, 0b1111111, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000),
+            BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
+            BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
+            BinaryToFloat(0, 0b00000000, 0b1111111, 0b0000000000000000)}));
+
 TEST(Bfloat16Test, Conversion) {
   float a[100];
   for (int i = 0; i < 100; ++i) {
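To make one of the truncation cases above concrete: BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011) assembles the bit pattern 0x4048F5C3, which is exactly the float literal 3.14f, and the expected value zeroes the low 16 bits, giving 0x40480000 = 3.125f. A self-contained sketch of the same check, with BinaryToFloat re-derived locally and std::memcpy standing in for TensorFlow's bit_cast:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    float BinaryToFloat(uint32_t sign, uint32_t exponent,
                        uint32_t high_mantissa, uint32_t low_mantissa) {
      uint32_t bits = (sign << 31) + (exponent << 23) +
                      (high_mantissa << 16) + low_mantissa;
      float f;
      std::memcpy(&f, &bits, sizeof(f));  // bit-level reinterpretation
      return f;
    }

    int main() {
      float input = BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011);
      float expected = BinaryToFloat(0, 0b10000000, 0b1001000, 0);
      assert(input == 3.14f);      // same bit pattern as the literal 3.14f
      assert(expected == 3.125f);  // low 16 mantissa bits zeroed
      return 0;
    }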
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index a630bee38d8..2b080e13fdb 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -44,6 +44,7 @@ typedef Eigen::QUInt16 quint16;
 // see framework/bfloat16.h for description.
 struct bfloat16 {
   EIGEN_DEVICE_FUNC bfloat16() {}
+
   EIGEN_DEVICE_FUNC explicit bfloat16(const float v) {
     const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
     value = p[0];
 #else
     value = p[1];
 #endif
   }
@@ -53,20 +54,92 @@ struct bfloat16 {
 #endif
   }
 
+  template <class T>
+  explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
+      : bfloat16(static_cast<float>(val)) {}
+
+  EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
+    float result;
+
+    uint16_t* q = reinterpret_cast<uint16_t*>(&result);
+
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+    q[0] = value;
+    q[1] = 0;
+#else
+    q[0] = 0;
+    q[1] = value;
+#endif
+    return result;
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator bool() const {
+    return static_cast<bool>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator Eigen::half() const {
+    return static_cast<Eigen::half>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator short() const {
+    return static_cast<short>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator int() const {
+    return static_cast<int>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator char() const {
+    return static_cast<char>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator signed char() const {
+    return static_cast<signed char>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned char() const {
+    return static_cast<unsigned char>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned int() const {
+    return static_cast<unsigned int>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned long() const {
+    return static_cast<unsigned long>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned long long() const {
+    return static_cast<unsigned long long>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator long long() const {
+    return static_cast<long long>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator double() const {
+    return static_cast<double>(float(*this));
+  }
+
   uint16_t value;
 };
 
+inline bool operator==(const bfloat16 a, const bfloat16 b) {
+  return a.value == b.value;
+}
+
+inline bool operator!=(const bfloat16 a, const bfloat16 b) {
+  return a.value != b.value;
+}
+
 }  // end namespace tensorflow
 
 namespace Eigen {
 template <>
 struct NumTraits<tensorflow::bfloat16>
     : GenericNumTraits<tensorflow::bfloat16> {};
 
-EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a,
-                                    const tensorflow::bfloat16 b) {
-  return a.value == b.value;
-}
-
+using ::tensorflow::operator==;
+using ::tensorflow::operator!=;
 }  // namespace Eigen
 
 #ifdef COMPILER_MSVC
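Taken together, the patch lets BF16 flow through XLA literals end to end: element storage (the new bf16s_ vector), protobuf serialization with endian normalization, type conversion, and equality. A sketch of exercising that path against the APIs added above (header paths and the error-handling style are assumptions of this sketch, not vetted build instructions):

    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/xla_data.pb.h"

    void RoundTripBF16() {
      using xla::bfloat16;
      using xla::Literal;

      // Rank-1 BF16 literal; 3.14f is stored truncated, as 3.125.
      auto literal =
          Literal::CreateR1<bfloat16>({bfloat16(1.0f), bfloat16(3.14f)});

      // Serialization: BF16 elements land in the little-endian `bf16s`
      // bytes field, with ConvertEndianShort fixing up big-endian hosts.
      xla::LiteralProto proto = literal->ToProto();

      // Widening conversion registered for BF16 by this patch.
      auto as_f32 = literal->Convert(xla::F32).ConsumeValueOrDie();
    }

Storing bf16s (like f16s) as a bytes field rather than a repeated numeric field is forced by protobuf, which has no 16-bit scalar type; fixing the wire format as little endian is what makes the ConvertEndianShort round trip in literal_util.cc necessary on big-endian hosts.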