Add bfloat16 support to XLA.

This is a prerequisite for bfloat16 support in the GPU backend.

RELNOTES: bfloat16 support is now added to the XLA infrastructure.
PiperOrigin-RevId: 175252067
commit 64d9aa1ace
parent 685f604f63
@@ -49,6 +49,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
     case tensorflow::DT_UINT64:
       *type = xla::U64;
       return Status::OK();
+    case tensorflow::DT_BFLOAT16:
+      *type = xla::BF16;
+      return Status::OK();
     case tensorflow::DT_HALF:
       *type = xla::F16;
       return Status::OK();
@@ -77,6 +77,7 @@ cc_library(
     hdrs = ["types.h"],
     visibility = [":friends"],
     deps = [
+        "//tensorflow/core:framework_lite",
         "//tensorflow/core:lib",
         "//third_party/eigen3",
     ],
@@ -33,6 +33,20 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 
+namespace {
+using tensorflow::int64;
+
+constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
+
+// Converts between little and big endian, assuming the elements in the array
+// are 16 bits long.
+void ConvertEndianShort(char* bytes, int64 size) {
+  CHECK_EQ(size % 2, 0);
+  for (int64 i = 0; i < size; i += 2) {
+    std::swap(bytes[i], bytes[i + 1]);
+  }
+}
+}  // namespace
 
 namespace xla {
 
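Note: the helper above is easy to exercise in isolation. Below is a minimal standalone sketch of the same byte swap, with assert standing in for CHECK_EQ; the bit patterns in main are illustrative assumptions, not taken from the commit.

#include <cassert>
#include <cstdint>
#include <utility>

// Swaps the two bytes of every 16-bit element in place, mirroring the
// ConvertEndianShort helper added above.
void ConvertEndianShort(char* bytes, int64_t size) {
  assert(size % 2 == 0);  // stands in for CHECK_EQ(size % 2, 0)
  for (int64_t i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}

int main() {
  // bfloat16 bit patterns for 1.0f (0x3F80) and 2.0f (0x4000).
  uint16_t values[2] = {0x3F80, 0x4000};
  ConvertEndianShort(reinterpret_cast<char*>(values), sizeof(values));
  // On a little-endian host each element is now byte-swapped.
  assert(values[0] == 0x803F && values[1] == 0x0040);
}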
@@ -169,6 +183,8 @@ Status Literal::Copy(const Literal& src_literal,
       return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
     case F16:
       return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
+    case BF16:
+      return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
     case F32:
       return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
     case F64:
@@ -200,6 +216,8 @@ Status Literal::Copy(const Literal& src_literal,
       return *Literal::CreateR0<int64>(0);
     case F16:
       return *Literal::CreateR0<half>(static_cast<half>(0.0f));
+    case BF16:
+      return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
     case F32:
       return *Literal::CreateR0<float>(0);
     case F64:
@@ -285,6 +303,9 @@ Status Literal::Copy(const Literal& src_literal,
     case F16:
       return *Literal::CreateR0<half>(
           static_cast<half>(-std::numeric_limits<float>::infinity()));
+    case BF16:
+      return *Literal::CreateR0<bfloat16>(
+          static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no minimum value";
     case OPAQUE:
@@ -321,6 +342,9 @@ Status Literal::Copy(const Literal& src_literal,
     case F16:
       return *Literal::CreateR0<half>(
           static_cast<half>(std::numeric_limits<float>::infinity()));
+    case BF16:
+      return *Literal::CreateR0<bfloat16>(
+          static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no maximum value";
     case OPAQUE:
@@ -428,6 +452,7 @@ std::unique_ptr<Literal> Literal::Transpose(
   // The shape with affine layout resulting from that operation will be
   // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
   // most minor.
+  //
   // Essentially, given MinMaj(Di) the position of the Di dimension within the
   // minor to major vector, and given T(Di) the index that the original Di
   // dimension has within the transposed array, a layout is affine if
@@ -536,6 +561,9 @@ string Literal::GetAsString(
     }
     case F16:
       return tensorflow::strings::StrCat(Get<half>(multi_index));
+    case BF16:
+      return tensorflow::strings::StrCat(
+          static_cast<float>(Get<bfloat16>(multi_index)));
     default:
       return tensorflow::strings::StrCat(
           "[", PrimitiveType_Name(shape().element_type()), "]");
@@ -743,6 +771,8 @@ void* Literal::MutableInternalData() {
       return reinterpret_cast<void*>(c64s_.data());
     case F16:
       return reinterpret_cast<void*>(f16s_.data());
+    case BF16:
+      return reinterpret_cast<void*>(bf16s_.data());
     default:
       LOG(FATAL) << "primitive type not supported in literals: "
                  << PrimitiveType_Name(shape().element_type());
@@ -785,6 +815,9 @@ void Literal::Reserve(int64 num_elements) {
     case F16:
       Resize<half>(num_elements, static_cast<half>(0.0f));
       break;
+    case BF16:
+      Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
+      break;
     default:
       LOG(FATAL) << "primitive type not supported in literals: "
                  << PrimitiveType_Name(shape().element_type());
@@ -824,6 +857,9 @@ tensorflow::Status Literal::ValidateLiteral() const {
     case F16:
       actual = f16s().size() / sizeof(half);
       break;
+    case BF16:
+      actual = bf16s().size();
+      break;
     default:
       return tensorflow::errors::Unimplemented(
           "unhandled element type for literal validation: " +
@@ -920,6 +956,7 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
     CONVERT_IF_TYPES_MATCH(F16)
     CONVERT_IF_TYPES_MATCH(F32)
     CONVERT_IF_TYPES_MATCH(F64)
+    CONVERT_IF_TYPES_MATCH(BF16)
 #undef CONVERT_IF_TYPES_MATCH
     case C64:
       return ConvertToC64<primitive_src_type>(src_literal);
@@ -949,6 +986,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
     CONVERT_IF_DEST_TYPE_MATCHES(F16)
     CONVERT_IF_DEST_TYPE_MATCHES(F32)
     CONVERT_IF_DEST_TYPE_MATCHES(F64)
+    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
 #undef CONVERT_IF_DEST_TYPE_MATCHES
     // Other types are not yet supported.
     default:
@@ -1019,6 +1057,8 @@ bool Literal::operator==(const Literal& other) const {
       return EqualElements<double>(*this, other, 0, &multi_index);
     case F16:
       return EqualElements<half>(*this, other, 0, &multi_index);
+    case BF16:
+      return EqualElements<bfloat16>(*this, other, 0, &multi_index);
     case C64:
       return EqualElements<complex64>(*this, other, 0, &multi_index);
     default:
@@ -1128,13 +1168,18 @@ tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
 
 template <>
 tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
-  // TODO - there is an endianess problem here. fix it, or wait for uint16
-  // support in protobuf
   auto values = mutable_f16s();
   return tensorflow::gtl::MutableArraySlice<half>(values->data(),
                                                   values->size());
 }
 
+template <>
+tensorflow::gtl::MutableArraySlice<bfloat16>
+Literal::GetMutableArraySlice<bfloat16>() {
+  auto values = mutable_bf16s();
+  return {values->data(), values->size()};
+}
+
 template <>
 tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
   CHECK_EQ(shape().element_type(), PRED);
@@ -1205,6 +1250,12 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
                                            f16s().size() / sizeof(half));
 }
 
+template <>
+tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
+  CHECK_EQ(shape().element_type(), BF16);
+  return {bf16s().data(), bf16s().size()};
+}
+
 template <>
 tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
     const {
@@ -1253,6 +1304,9 @@ bool Literal::IsAll(int8 value) const {
       return AllElementsEqualValue<double>(*this, value);
     case F16:
       return AllElementsEqualValue<half>(*this, static_cast<half>(value));
+    case BF16:
+      return AllElementsEqualValue<bfloat16>(*this,
+                                             static_cast<bfloat16>(value));
     case PRED:
       if (value == 0) {
         return AllElementsEqualValue<bool>(*this, false);
@@ -1274,6 +1328,9 @@ bool Literal::IsAllFloat(float value) const {
      return AllElementsEqualValue<double>(*this, value);
     case F16:
       return AllElementsEqualValue<half>(*this, static_cast<half>(value));
+    case BF16:
+      return AllElementsEqualValue<bfloat16>(*this,
+                                             static_cast<bfloat16>(value));
     default:
       return false;
   }
@@ -1310,6 +1367,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
       return Get<complex64>(indices) == complex64(0.0f, 0.0f);
     case F16:
       return Get<half>(indices) == static_cast<half>(0.0f);
+    case BF16:
+      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
     case PRED:
       return Get<bool>(indices) == false;
     default:
@@ -1377,6 +1436,12 @@ void Literal::Resize<half>(int64 num_elements, half value) {
   mutable_f16s()->resize(num_elements, value);
 }
 
+template <>
+void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
+  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
+  mutable_bf16s()->resize(num_elements, value);
+}
+
 template <>
 void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
   CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
@@ -1425,6 +1490,19 @@ LiteralProto Literal::ToProto() const {
       *proto.mutable_f16s() =
           string(reinterpret_cast<const char*>(f16s_.data()),
                  f16s_.size() * sizeof(half));
+      if (!kLittleEndian) {
+        ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
+                           proto.f16s().size());
+      }
+      break;
+    case BF16:
+      *proto.mutable_bf16s() =
+          string(reinterpret_cast<const char*>(bf16s_.data()),
+                 bf16s_.size() * sizeof(bfloat16));
+      if (!kLittleEndian) {
+        ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
+                           proto.bf16s().size());
+      }
       break;
     case F32:
       CopyToRepeatedField(proto.mutable_f32s(), f32s());
@@ -1493,6 +1571,21 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
       CHECK_EQ(0, s.size() % sizeof(half));
       f16s_ = std::vector<half>(s.size() / sizeof(half));
       memcpy(f16s_.data(), s.data(), s.size());
+
+      if (!kLittleEndian) {
+        ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
+      }
+      break;
+    }
+    case BF16: {
+      const string& s(literal_proto.bf16s());
+      CHECK_EQ(0, s.size() % sizeof(bfloat16));
+      bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
+      memcpy(bf16s_.data(), s.data(), s.size());
+
+      if (!kLittleEndian) {
+        ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
+      }
       break;
     }
     case F32:
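The wire format used by the f16s/bf16s proto fields above is just the raw 16-bit elements in little-endian byte order. A standalone sketch of that pack/unpack round trip (Pack and Unpack are hypothetical helper names, and a little-endian host is assumed so no swap is needed):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Packs 16-bit values into a byte string, the same layout the proto fields
// store, and unpacks them again.
std::string Pack(const std::vector<uint16_t>& values) {
  return std::string(reinterpret_cast<const char*>(values.data()),
                     values.size() * sizeof(uint16_t));
}

std::vector<uint16_t> Unpack(const std::string& s) {
  assert(s.size() % sizeof(uint16_t) == 0);
  std::vector<uint16_t> values(s.size() / sizeof(uint16_t));
  std::memcpy(values.data(), s.data(), s.size());
  return values;
}

int main() {
  // bfloat16 bit patterns for 1.0f and 2.0f.
  std::vector<uint16_t> bf16_bits = {0x3F80, 0x4000};
  assert(Unpack(Pack(bf16_bits)) == bf16_bits);
}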
@@ -163,6 +163,11 @@ class Literal {
   const std::vector<complex64>& c64s() const { return c64s_; }
   std::vector<complex64>* mutable_c64s() { return &c64s_; }
 
+  int bf16s_size() const { return bf16s().size(); }
+  bfloat16 bf16s(int i) const { return bf16s_[i]; }
+  const std::vector<bfloat16>& bf16s() const { return bf16s_; }
+  std::vector<bfloat16>* mutable_bf16s() { return &bf16s_; }
+
   int tuple_literals_size() const { return tuple_literals().size(); }
   const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
   Literal* add_tuple_literals() {
@@ -622,6 +627,7 @@ class Literal {
   std::vector<uint16> u16s_;
   std::vector<uint32> u32s_;
   std::vector<uint64> u64s_;
+  std::vector<bfloat16> bf16s_;
   std::vector<half> f16s_;
   std::vector<float> f32s_;
   std::vector<double> f64s_;
@@ -674,6 +680,9 @@ tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const;
 template <>
 tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
 
+template <>
+tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const;
+
 template <>
 tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
     const;
@@ -714,6 +723,9 @@ tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice();
 template <>
 tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
 
+template <>
+tensorflow::gtl::MutableArraySlice<bfloat16> Literal::GetMutableArraySlice();
+
 template <>
 tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
 
@@ -747,6 +759,9 @@ void Literal::Resize<double>(int64 num_elements, double value);
 template <>
 void Literal::Resize<half>(int64 num_elements, half value);
 
+template <>
+void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value);
+
 template <>
 void Literal::Resize<complex64>(int64 num_elements, complex64 value);
 
@@ -990,6 +1005,14 @@ inline half Literal::Get<half>(
   return GetArraySlice<half>()[linear_index];
 }
 
+template <>
+inline bfloat16 Literal::Get<bfloat16>(
+    tensorflow::gtl::ArraySlice<int64> multi_index) const {
+  CHECK(shape().element_type() == BF16);
+  int64 linear_index = LinearIndex(multi_index);
+  return GetArraySlice<bfloat16>()[linear_index];
+}
+
 template <typename NativeT>
 void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
                   NativeT value) {
@@ -110,6 +110,18 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
 
   auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
   ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
+
+  auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
+  ASSERT_EQ("0.5", bf16_lit->ToString());
+
+  // 3.14 will be rounded to 3.140625 in bfloat16 format (round to nearest even).
+  auto bf16_lit_truncated =
+      Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
+  ASSERT_EQ("3.140625", bf16_lit_truncated->ToString());
+
+  auto bf16_lit_truncated2 =
+      Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
+  ASSERT_EQ("9", bf16_lit_truncated2->ToString());
 }
 
 TEST_F(LiteralUtilTest, LiteralVectorToString) {
@@ -397,6 +409,18 @@ TEST_F(LiteralUtilTest, IsAll) {
   EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
   EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
 
+  bfloat16 b8(8.0f);
+  bfloat16 b9(9.0f);
+
+  EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
+  EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
+  EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+
+  // 9.001 will be rounded to 9.0 in bfloat16 format.
+  bfloat16 b91(9.001f);
+  bfloat16 b90(9.00f);
+  EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+
   complex64 c8_9 = {8, 9};
   EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
 
@@ -691,6 +715,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
   EXPECT_EQ(output, *expected);
 }
 
+TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
+  Literal output;
+  bfloat16 h(0.25f);
+  output.PopulateWithValue<bfloat16>(h, {});
+  auto expected = Literal::CreateR0<bfloat16>(h);
+  EXPECT_EQ(output, *expected);
+}
+
+TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
+  Literal output;
+  bfloat16 h(0.5f);
+  output.PopulateWithValue<bfloat16>(h, {3});
+  auto expected = Literal::CreateR1<bfloat16>({h, h, h});
+  EXPECT_EQ(output, *expected);
+}
+
+TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
+  Literal output;
+  bfloat16 h(2.0f);
+  output.PopulateWithValue<bfloat16>(h, {2, 2});
+  auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
+  EXPECT_EQ(output, *expected);
+}
+
 TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
   Literal output;
   output.PopulateWithValue<float>(2.5f, {});
@@ -975,6 +1023,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
     {{half(26.0), half(0.0), half(28.0), half(0.0)},
      {half(0.0), half(31.0), half(0.0), half(33.0)}},
   }}, layout_r4_dim0major_);
+  auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{
+    {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
+     {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
+    {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
+     {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
+    {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
+     {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
+  }}, layout_r4_dim0major_);
   auto f32 = Literal::CreateR4WithLayout<float>({{
     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
@@ -1008,6 +1064,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
   conv = s8->Convert(PRED).ConsumeValueOrDie();
   EXPECT_EQ(*conv, *pred);
 
+  conv = bf16->Convert(S32).ConsumeValueOrDie();
+  EXPECT_EQ(*conv, *s32);
+
+  conv = bf16->Convert(F32).ConsumeValueOrDie();
+  EXPECT_EQ(*conv, *f32);
+
   conv = pred->Convert(S32).ConsumeValueOrDie();
   EXPECT_EQ(*conv, *int32_pred);
 
@@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType<double>() {
   return F64;
 }
 
+template <>
+PrimitiveType NativeToPrimitiveType<bfloat16>() {
+  return BF16;
+}
+
 template <>
 PrimitiveType NativeToPrimitiveType<half>() {
   return F16;
@@ -89,7 +94,7 @@ PrimitiveType NativeToPrimitiveType<complex64>() {
 }
 
 bool IsFloatingPointType(PrimitiveType type) {
-  return type == F16 || type == F32 || type == F64;
+  return type == F16 || type == F32 || type == F64 || type == BF16;
 }
 
 bool IsComplexType(PrimitiveType type) { return type == C64; }
@@ -118,6 +123,7 @@ int BitWidth(PrimitiveType type) {
     case S16:
     case U16:
     case F16:
+    case BF16:
       return 16;
 
     case U32:
@@ -77,6 +77,8 @@ template <>
 PrimitiveType NativeToPrimitiveType<double>();
 template <>
 PrimitiveType NativeToPrimitiveType<half>();
+template <>
+PrimitiveType NativeToPrimitiveType<bfloat16>();
 
 // Complex
 template <>
@@ -167,6 +169,11 @@ struct PrimitiveTypeToNative<F16> {
   using type = half;
 };
 
+template <>
+struct PrimitiveTypeToNative<BF16> {
+  using type = bfloat16;
+};
+
 // Complex
 template <>
 struct PrimitiveTypeToNative<C64> {
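The two mappings added in this file are meant to be inverses: NativeToPrimitiveType<bfloat16>() returns the BF16 enum value, and PrimitiveTypeToNative<BF16> maps back to the native type. A small compile-time sketch of that relationship, assuming the xla::primitive_util namespace layout of primitive_util.h at this revision:

#include <type_traits>
#include "tensorflow/compiler/xla/primitive_util.h"  // assumed include

// Checks that the enum->native direction added above lands on bfloat16.
static_assert(
    std::is_same<xla::primitive_util::PrimitiveTypeToNative<xla::BF16>::type,
                 tensorflow::bfloat16>::value,
    "PrimitiveTypeToNative<BF16> must map to bfloat16");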
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#define EIGEN_USE_THREADS
+
 #include "tensorflow/compiler/xla/service/backend.h"
 
 #include <algorithm>
 #include <string>
 #include <utility>
 
-#define EIGEN_USE_THREADS
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
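This reordering, repeated in several files below, is load-bearing: EIGEN_USE_THREADS only has an effect if it is defined before the first Eigen Tensor header is preprocessed, because it gates what that header declares (notably Eigen::ThreadPoolDevice). A minimal sketch of the required ordering:

// The macro must precede the include; defining it afterwards does nothing,
// since the Tensor header has already been parsed without it.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// With the macro visible, the thread-pool device is available, e.g.:
//   Eigen::ThreadPoolDevice device(&pool, num_threads);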
@@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#define EIGEN_USE_THREADS
 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
 
 #include <memory>
 #include <string>
 #include <tuple>
 
-#define EIGEN_USE_THREADS
-
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
@@ -1450,6 +1450,10 @@ HloEvaluator::HloEvaluator() {
   typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
   typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
   typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
+
+  typed_visitors_[BF16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+    return Unimplemented("HloEvaluator: unhandled primitive type: BF16.");
+  });
   typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
     return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
   });
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#define EIGEN_USE_THREADS
 
 #include "tensorflow/compiler/xla/service/hlo_runner.h"
 
@@ -19,8 +20,6 @@ limitations under the License.
 #include <string>
 #include <utility>
 
-#define EIGEN_USE_THREADS
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
@@ -263,6 +263,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
     case S32:
     case S64:
     case F16:
+    case BF16:
     case F32:
     case F64:
       return true;
@@ -116,16 +116,18 @@ template <typename FloatT, typename UnsignedT>
 ::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
   auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
   auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
+  auto lhs_double = static_cast<double>(lhs);
+  auto rhs_double = static_cast<double>(rhs);
   if (ulhs != urhs) {
     return ::testing::AssertionFailure() << tensorflow::strings::Printf(
                "floating values are not bitwise-equal; and equality testing "
                "was requested: %s=%g=%a vs %s=%g=%a",
                tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs))
                    .c_str(),
-               lhs, lhs,
+               lhs_double, lhs_double,
               tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs))
                    .c_str(),
-               rhs, rhs);
+               rhs_double, rhs_double);
   }
   return ::testing::AssertionSuccess();
 }
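The added static_casts appear to be needed because lhs and rhs now flow into a printf-style varargs call: a 16-bit float type such as bfloat16 does not undergo the default argument promotion that %g expects, so the values are widened to double first. The bitwise comparison itself is simple; a standalone sketch, where Bf16 is a hypothetical stand-in for tensorflow::bfloat16:

#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for tensorflow::bfloat16: just the 16-bit payload.
struct Bf16 {
  uint16_t value;
};

// Bitwise equality in the spirit of CompareFloatsBitwiseEqual above: two
// values are equal iff their bit patterns match exactly (so NaN == NaN, and
// +0.0 != -0.0 under this definition).
bool BitwiseEqual(Bf16 lhs, Bf16 rhs) {
  uint16_t ulhs, urhs;
  std::memcpy(&ulhs, &lhs, sizeof(ulhs));  // poor man's bit_cast
  std::memcpy(&urhs, &rhs, sizeof(urhs));
  return ulhs == urhs;
}

int main() {
  Bf16 a{0x4110}, b{0x4110};  // both are 9.0 in bfloat16
  assert(BitwiseEqual(a, b));
  Bf16 c{0x0000}, d{0x8000};  // +0.0 and -0.0 differ bitwise
  assert(!BitwiseEqual(c, d));
}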
@@ -149,6 +151,10 @@ template <typename NativeT>
 // Specializations for floating types that do bitwise comparisons when equality
 // comparison is requested.
 template <>
+::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
+  return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
+}
+template <>
 ::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
   return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
 }
@@ -238,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
     case U64:
       match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
       break;
+    case BF16:
+      match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
+      break;
     case F32:
       match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
       break;
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#define EIGEN_USE_THREADS
+
 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
 
 #include <vector>
 
-#define EIGEN_USE_THREADS
-
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/map_util.h"
@@ -19,6 +19,7 @@ limitations under the License.
 #include <complex>
 
 #include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/platform/types.h"
 
 #include <Eigen/Core>
@@ -32,6 +33,8 @@ using ::tensorflow::int16;
 using ::tensorflow::int32;
 using ::tensorflow::int64;
 
+using ::tensorflow::bfloat16;
+
 using ::tensorflow::uint8;
 using ::tensorflow::uint16;
 using ::tensorflow::uint32;
@@ -46,6 +46,12 @@ enum PrimitiveType {
   // converted to f16 from f32 at arbitrary points in the computation.
   F16 = 10;
   F32 = 11;
 
+  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
+  // floating-point format, but uses 1 bit for the sign, 8 bits for the
+  // exponent and 7 bits for the mantissa.
+  BF16 = 16;
+
   F64 = 12;
 
   // Complex values of fixed width.
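The layout described in that comment means a bfloat16 is exactly the high half of a float32: same sign bit, same 8 exponent bits, and the top 7 of the 23 mantissa bits. A standalone sketch of that relationship (truncation only; the rounding added later in this commit is noted in the comments):

#include <cassert>
#include <cstdint>
#include <cstring>

// Drops the low 16 bits of a float32, which is precisely truncation to
// bfloat16 given the 1/8/7 layout described above.
uint16_t TruncateToBfloat16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  // 1.0f = 0x3F800000, so its bfloat16 bit pattern is 0x3F80.
  assert(TruncateToBfloat16Bits(1.0f) == 0x3F80);
  // 3.14f = 0x4048F5C3 truncates to 0x4048 (3.125); round-to-nearest-even,
  // as implemented later in this commit, instead gives 0x4049 (3.140625).
  assert(TruncateToBfloat16Bits(3.14f) == 0x4048);
}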
@@ -63,6 +69,8 @@ enum PrimitiveType {
   // An opaque type used for passing context specific data to a custom
   // operation.
   OPAQUE = 14;
+
+  // Next = 17
 }
 
 // Describes the value held inside padding elements.
@@ -310,7 +318,10 @@ message LiteralProto {
   repeated double f64s = 9;
   repeated float c64s = 12;  // Stored as interleaved real, imag floats.
   repeated LiteralProto tuple_literals = 10;
-  bytes f16s = 11;  // Note: the F16s are encoded in little endian byte order
+  // The F16s and BF16s are encoded in little endian byte order
+  bytes f16s = 11;
+  bytes bf16s = 13;
+  // Next = 14
 }
 
 message WindowDimension {
@@ -18,17 +18,9 @@ limitations under the License.
 namespace tensorflow {
 
 void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
-  const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
-  uint16_t* q = reinterpret_cast<uint16_t*>(dst);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-  for (; size != 0; p += 2, q++, size--) {
-    *q = p[0];
-  }
-#else
-  for (; size != 0; p += 2, q++, size--) {
-    *q = p[1];
-  }
-#endif
+  for (int64 i = 0; i < size; ++i) {
+    dst[i] = bfloat16(src[i]);
+  }
 }
 
 void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/bfloat16.h"
 
+#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 
@@ -27,6 +28,97 @@ TEST(Bfloat16Test, Simple) {
   EXPECT_EQ(0x4140, a.value);
 }
 
+float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
+                    uint32_t low_mantissa) {
+  return bit_cast<float>((sign << 31) + (exponent << 23) +
+                         (high_mantissa << 16) + low_mantissa);
+}
+
+struct Bfloat16TestParam {
+  float input;
+  float expected;
+};
+
+class Bfloat16Test : public ::testing::Test,
+                     public ::testing::WithParamInterface<Bfloat16TestParam> {};
+
+TEST_P(Bfloat16Test, RoundOrTruncate) {
+  bfloat16 a(GetParam().input);
+  if (std::isnan(GetParam().input)) {
+    EXPECT_TRUE(std::isnan(float(a)));
+    return;
+  }
+  EXPECT_EQ(GetParam().expected, float(a));
+}
+
+INSTANTIATE_TEST_CASE_P(
+    Bfloat16Test_Instantiation, Bfloat16Test,
+    ::testing::Values(
+        // More than half.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011),
+            BinaryToFloat(0, 0b10000000, 0b1001001, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011),
+            BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
+        // Exact half.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+        // NaN stays at NaN.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001),
+            BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
+        // NaN stays at NaN -- no exponent overflow.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111),
+            BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
+        // More than half, round to an odd number.
+        Bfloat16TestParam{
+            BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000),
+            BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
+        // Less than half, truncate.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+        // Less than half, truncate.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+        // Exact at half, but result is already even.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+        // Denormal values.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
+            BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
+            BinaryToFloat(0, 0b00000001, 0b0000000, 0b0000000000000000)}));
+
+TEST(Bfloat16Test, RoundWithFractionOverflow) {
+  // Still works with fraction overflow -- rounds to 4.0.
+  //
+  // Input 3.9960938:
+  // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+  //  0     1 0 0 0 0 0 0 0   1 1 1 1 1 1 1      1100000000000000
+  //
+  // Should round to 4.0:
+  // Sign | Exp (8 bit) | Frac (first 7 bit)
+  //  0     1 0 0 0 0 0 0 1   0 0 0 0 0 0 0
+  bfloat16 a(3.9960938f);
+  EXPECT_EQ(4.0, float(a));
+}
+
 TEST(Bfloat16Test, Conversion) {
   float a[100];
   for (int i = 0; i < 100; ++i) {
@@ -44,29 +44,262 @@ typedef Eigen::QUInt16 quint16;
 // see framework/bfloat16.h for description.
 struct bfloat16 {
   EIGEN_DEVICE_FUNC bfloat16() {}
 
-  EIGEN_DEVICE_FUNC explicit bfloat16(const float v) {
-    const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-    value = p[0];
-#else
-    value = p[1];
-#endif
-  }
+  explicit EIGEN_DEVICE_FUNC bfloat16(float v) {
+    uint32_t input;
+    memcpy(&input, &v, sizeof(uint32_t));
+
+    if ((~input & 0x7f800000) == 0 && (input & 0x007fffff) != 0) {
+      // If the value is a NaN, squash it to a qNaN with the msb of the
+      // fraction set; this makes sure that after truncation we don't end up
+      // with an inf.
+      //
+      // qNaN magic: all exponent bits set + most significant bit of fraction
+      // set.
+      value = 0x7fc0;
+    } else {
+      // Fast rounding algorithm that rounds a half value to nearest even. This
+      // reduces expected error when we convert a large number of floats. Here
+      // is how it works:
+      //
+      // Definitions:
+      // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
+      // with the following tags:
+      //
+      // Sign |  Exp (8 bits) | Frac (23 bits)
+      //  S     EEEEEEEE        FFFFFFLRTTTTTTTTTTTTTTT
+      //
+      //  S: Sign bit.
+      //  E: Exponent bits.
+      //  F: First 6 bits of fraction.
+      //  L: Least significant bit of resulting bfloat16 if we truncate away
+      //     the rest of the float32. This is also the 7th bit of fraction.
+      //  R: Rounding bit, 8th bit of fraction.
+      //  T: Sticky bits, rest of fraction, 15 bits.
+      //
+      // To round half to nearest even, there are three cases where we want to
+      // round down (simply truncate the rest of the bits away, i.e. the
+      // rounding bit and the sticky bits) and two cases where we want to round
+      // up (truncate, then add one to the result).
+      //
+      // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits
+      // of 1s) as the rounding bias, adds the rounding bias to the input, then
+      // truncates the last 16 bits away.
+      //
+      // To understand how it works, we can analyze this algorithm case by
+      // case:
+      //
+      // 1. L = 0, R = 0:
+      //   Expect: round down, this is less than half value.
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 0 = 0x7fff
+      //   - Adding the rounding bias to the input may create a carry,
+      //     depending on whether any value is set to 1 in the T bits.
+      //   - R may be set to 1 if there is a carry.
+      //   - L remains 0.
+      //   - Note that this case also handles Inf and -Inf, where all fraction
+      //     bits, including L, R and Ts, are all 0. The output remains Inf
+      //     after this algorithm.
+      //
+      // 2. L = 1, R = 0:
+      //   Expect: round down, this is less than half value.
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 1 = 0x8000
+      //   - Adding the rounding bias to the input doesn't change the sticky
+      //     bits but adds 1 to the rounding bit.
+      //   - L remains 1.
+      //
+      // 3. L = 0, R = 1, all of T are 0:
+      //   Expect: round down, this is exactly at half, the result is already
+      //   even (L = 0).
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 0 = 0x7fff
+      //   - Adding the rounding bias to the input sets all sticky bits to 1,
+      //     but doesn't create a carry.
+      //   - R remains 1.
+      //   - L remains 0.
+      //
+      // 4. L = 1, R = 1:
+      //   Expect: round up, this is exactly at half, the result needs to be
+      //   rounded to the next even number.
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 1 = 0x8000
+      //   - Adding the rounding bias to the input doesn't change the sticky
+      //     bits, but creates a carry from the rounding bit.
+      //   - The carry sets L to 0, creates another carry bit and propagates
+      //     forward to the F bits.
+      //   - If all the F bits are 1, a carry then propagates to the exponent
+      //     bits, which then creates the minimum value with the next exponent
+      //     value. Note that we won't have the case where exponents are all 1,
+      //     since that's either a NaN (handled in the other if condition) or
+      //     inf (handled in case 1).
+      //
+      // 5. L = 0, R = 1, any of T is 1:
+      //   Expect: round up, this is greater than half.
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 0 = 0x7fff
+      //   - Adding the rounding bias to the input creates a carry from the
+      //     sticky bits, sets the rounding bit to 0, then creates another
+      //     carry.
+      //   - The second carry sets L to 1.
+      //
+      // Examples:
+      //
+      //  Exact half value that is already even:
+      //    Input:
+      //    Sign | Exp (8 bit)      | Frac (first 7 bit) | Frac (last 16 bit)
+      //     S     E E E E E E E E    F F F F F F L        RTTTTTTTTTTTTTTT
+      //     0     0 0 0 0 0 0 0 0    0 0 0 0 0 1 0        1000000000000000
+      //
+      //     This falls into case 3. We truncate the rest of the 16 bits and no
+      //     carry is created into F and L:
+      //
+      //    Output:
+      //    Sign | Exp (8 bit)      | Frac (first 7 bit)
+      //     S     E E E E E E E E    F F F F F F L
+      //     0     0 0 0 0 0 0 0 0    0 0 0 0 0 1 0
+      //
+      //  Exact half value, round to next even number:
+      //    Input:
+      //    Sign | Exp (8 bit)      | Frac (first 7 bit) | Frac (last 16 bit)
+      //     S     E E E E E E E E    F F F F F F L        RTTTTTTTTTTTTTTT
+      //     0     0 0 0 0 0 0 0 0    0 0 0 0 0 0 1        1000000000000000
+      //
+      //     This falls into case 4. We create a carry from R and T, which then
+      //     propagates into L and F:
+      //
+      //    Output:
+      //    Sign | Exp (8 bit)      | Frac (first 7 bit)
+      //     S     E E E E E E E E    F F F F F F L
+      //     0     0 0 0 0 0 0 0 0    0 0 0 0 0 1 0
+      //
+      //  Max denormal value round to min normal value:
+      //    Input:
+      //    Sign | Exp (8 bit)      | Frac (first 7 bit) | Frac (last 16 bit)
+      //     S     E E E E E E E E    F F F F F F L        RTTTTTTTTTTTTTTT
+      //     0     0 0 0 0 0 0 0 0    1 1 1 1 1 1 1        1111111111111111
+      //
+      //     This falls into case 4. We create a carry from R and T, propagate
+      //     it into L and F, which then propagates into the exponent bits:
+      //
+      //    Output:
+      //    Sign | Exp (8 bit)      | Frac (first 7 bit)
+      //     S     E E E E E E E E    F F F F F F L
+      //     0     0 0 0 0 0 0 0 1    0 0 0 0 0 0 0
+      //
+      //  Max normal value round to Inf:
+      //    Input:
+      //    Sign | Exp (8 bit)      | Frac (first 7 bit) | Frac (last 16 bit)
+      //     S     E E E E E E E E    F F F F F F L        RTTTTTTTTTTTTTTT
+      //     0     1 1 1 1 1 1 1 0    1 1 1 1 1 1 1        1111111111111111
+      //
+      //     This falls into case 4. We create a carry from R and T, propagate
+      //     it into L and F, which then propagates into the exponent bits:
+      //
+      //    Output:
+      //    Sign | Exp (8 bit)      | Frac (first 7 bit)
+      //     S     E E E E E E E E    F F F F F F L
+      //     0     1 1 1 1 1 1 1 1    0 0 0 0 0 0 0
+
+      // Least significant bit of the resulting bfloat.
+      uint32_t lsb = (input >> 16) & 1;
+      uint32_t rounding_bias = 0x7fff + lsb;
+      input += rounding_bias;
+      value = static_cast<uint16_t>(input >> 16);
+    }
+  }
+
+  template <class T>
+  explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
+      : bfloat16(static_cast<float>(val)) {}
+
+  EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
+    float result;
+    uint16_t* q = reinterpret_cast<uint16_t*>(&result);
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+    q[0] = value;
+    q[1] = 0;
+#else
+    q[0] = 0;
+    q[1] = value;
+#endif
+    return result;
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator bool() const {
+    return static_cast<bool>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator Eigen::half() const {
+    return static_cast<Eigen::half>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator short() const {
+    return static_cast<short>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator int() const {
+    return static_cast<int>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator char() const {
+    return static_cast<char>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator signed char() const {
+    return static_cast<signed char>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned char() const {
+    return static_cast<unsigned char>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned int() const {
+    return static_cast<unsigned int>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned long() const {
+    return static_cast<unsigned long>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned long long() const {
+    return static_cast<unsigned long long>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator long long() const {
+    return static_cast<long long>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator double() const {
+    return static_cast<double>(float(*this));
+  }
 
   uint16_t value;
 };
 
+inline bool operator==(const bfloat16 a, const bfloat16 b) {
+  return a.value == b.value;
+}
+
+inline bool operator!=(const bfloat16 a, const bfloat16 b) {
+  return a.value != b.value;
+}
+
 }  // end namespace tensorflow
 
 namespace Eigen {
 template <>
 struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {};
 
-EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a,
-                                    const tensorflow::bfloat16 b) {
-  return a.value == b.value;
-}
+using ::tensorflow::operator==;
+using ::tensorflow::operator!=;
 
 }  // namespace Eigen
 
 #ifdef COMPILER_MSVC
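The bias trick in the constructor above is compact enough to verify by hand. A standalone re-implementation of just the round-to-nearest-even path (NaN handling omitted), checking the worked examples from the tests:

#include <cassert>
#include <cstdint>
#include <cstring>

// Reproduces the non-NaN branch of the constructor: add lsb to 0x7fff as the
// rounding bias, then drop the low 16 bits.
uint16_t RoundToBfloat16Bits(float v) {
  uint32_t input;
  std::memcpy(&input, &v, sizeof(input));
  uint32_t lsb = (input >> 16) & 1;       // the L bit of the result
  uint32_t rounding_bias = 0x7fff + lsb;  // 0x7fff or 0x8000
  input += rounding_bias;
  return static_cast<uint16_t>(input >> 16);
}

int main() {
  // 3.9960938f = 0x407FC000: the fraction overflow carries into the exponent,
  // rounding to 4.0f (0x4080), matching the RoundWithFractionOverflow test.
  assert(RoundToBfloat16Bits(3.9960938f) == 0x4080);
  // 3.14f = 0x4048F5C3 rounds up to 0x4049 (3.140625).
  assert(RoundToBfloat16Bits(3.14f) == 0x4049);
}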