Merge pull request #40962 from Intel-tensorflow:yang/eigen-bf16
PiperOrigin-RevId: 321267626 Change-Id: I62c174955a9ce3801158ebfb5aee23a40267c04d
This commit is contained in:
commit
25913db8b6
@ -896,7 +896,7 @@ static DenseElementsAttr GetEpsilonValue(Type ty) {
|
|||||||
auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
|
auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
|
||||||
return DenseElementsAttr::get(scalar_ty, value);
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
} else if (element_ty.isBF16()) {
|
} else if (element_ty.isBF16()) {
|
||||||
uint16_t raw_epsilon = tensorflow::bfloat16::epsilon().value;
|
uint16_t raw_epsilon = Eigen::NumTraits<Eigen::bfloat16>::epsilon().value;
|
||||||
auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
|
auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
|
||||||
return DenseElementsAttr::get(scalar_ty, value);
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
} else if (element_ty.isF32()) {
|
} else if (element_ty.isF32()) {
|
||||||
|
|||||||
@ -48,7 +48,9 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
builder,
|
builder,
|
||||||
static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
|
static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
|
builder, static_cast<Eigen::bfloat16>(
|
||||||
|
Eigen::NumTraits<Eigen::bfloat16>::epsilon()));
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
|
return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
|
||||||
case F64:
|
case F64:
|
||||||
@ -70,7 +72,8 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
return ConstantR0<Eigen::half>(builder,
|
return ConstantR0<Eigen::half>(builder,
|
||||||
Eigen::NumTraits<Eigen::half>::lowest());
|
Eigen::NumTraits<Eigen::half>::lowest());
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(builder, bfloat16::lowest());
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
|
builder, Eigen::NumTraits<Eigen::bfloat16>::lowest());
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
|
return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
|
||||||
case F64:
|
case F64:
|
||||||
@ -86,7 +89,8 @@ XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
return ConstantR0<Eigen::half>(builder,
|
return ConstantR0<Eigen::half>(builder,
|
||||||
std::numeric_limits<Eigen::half>::min());
|
std::numeric_limits<Eigen::half>::min());
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(builder, bfloat16::min_positive_normal());
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
|
builder, std::numeric_limits<Eigen::bfloat16>::min());
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder, std::numeric_limits<float>::min());
|
return ConstantR0<float>(builder, std::numeric_limits<float>::min());
|
||||||
case F64:
|
case F64:
|
||||||
@ -108,7 +112,8 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
return ConstantR0<Eigen::half>(builder,
|
return ConstantR0<Eigen::half>(builder,
|
||||||
Eigen::NumTraits<Eigen::half>::highest());
|
Eigen::NumTraits<Eigen::half>::highest());
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(builder, bfloat16::highest());
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
|
builder, Eigen::NumTraits<Eigen::bfloat16>::highest());
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder, std::numeric_limits<float>::max());
|
return ConstantR0<float>(builder, std::numeric_limits<float>::max());
|
||||||
case F64:
|
case F64:
|
||||||
@ -125,8 +130,8 @@ XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
return ConstantR0<Eigen::half>(
|
return ConstantR0<Eigen::half>(
|
||||||
builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
|
builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
builder, bfloat16(std::numeric_limits<float>::quiet_NaN()));
|
builder, Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN());
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder,
|
return ConstantR0<float>(builder,
|
||||||
std::numeric_limits<float>::quiet_NaN());
|
std::numeric_limits<float>::quiet_NaN());
|
||||||
|
|||||||
@ -218,23 +218,12 @@ int64 RecursiveElementCount(const Shape& shape) {
|
|||||||
// Returns whether the given value is infinity.
|
// Returns whether the given value is infinity.
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
bool IsInf(NativeT val) {
|
bool IsInf(NativeT val) {
|
||||||
return std::isinf(val);
|
return Eigen::numext::isinf(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
|
||||||
bool IsInf<half>(half val) {
|
|
||||||
return std::isinf(static_cast<float>(val));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns whether the given value is nan.
|
// Returns whether the given value is nan.
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
float IsNan(NativeT value) {
|
bool IsNan(NativeT value) {
|
||||||
return std::isnan(value);
|
return Eigen::numext::isnan(value);
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
float IsNan(half value) {
|
|
||||||
return IsNan<float>(static_cast<float>(value));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts the given floating-point value to a string.
|
// Converts the given floating-point value to a string.
|
||||||
|
|||||||
@ -455,10 +455,10 @@ int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
// NaNs sort to the end.
|
// NaNs sort to the end.
|
||||||
if (!std::isnan(x) && std::isnan(y)) {
|
if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
if (std::isnan(x) && !std::isnan(y)) {
|
if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
@ -962,7 +962,7 @@ struct Frexp {
|
|||||||
struct Heaviside {
|
struct Heaviside {
|
||||||
bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
|
bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
|
||||||
float x = static_cast<float>(bx);
|
float x = static_cast<float>(bx);
|
||||||
if (std::isnan(x)) {
|
if (Eigen::numext::isnan(x)) {
|
||||||
return bx;
|
return bx;
|
||||||
}
|
}
|
||||||
if (x < 0) {
|
if (x < 0) {
|
||||||
@ -984,7 +984,9 @@ struct IsInf {
|
|||||||
bool operator()(bfloat16 a) { return std::isinf(static_cast<float>(a)); }
|
bool operator()(bfloat16 a) { return std::isinf(static_cast<float>(a)); }
|
||||||
};
|
};
|
||||||
struct IsNan {
|
struct IsNan {
|
||||||
bool operator()(bfloat16 a) { return std::isnan(static_cast<float>(a)); }
|
bool operator()(bfloat16 a) {
|
||||||
|
return Eigen::numext::isnan(static_cast<float>(a));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
struct Ldexp {
|
struct Ldexp {
|
||||||
bfloat16 operator()(bfloat16 a, int exp) {
|
bfloat16 operator()(bfloat16 a, int exp) {
|
||||||
@ -1200,25 +1202,25 @@ struct Ge {
|
|||||||
struct Maximum {
|
struct Maximum {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
||||||
float fa(a), fb(b);
|
float fa(a), fb(b);
|
||||||
return std::isnan(fa) || fa > fb ? a : b;
|
return Eigen::numext::isnan(fa) || fa > fb ? a : b;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Minimum {
|
struct Minimum {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
||||||
float fa(a), fb(b);
|
float fa(a), fb(b);
|
||||||
return std::isnan(fa) || fa < fb ? a : b;
|
return Eigen::numext::isnan(fa) || fa < fb ? a : b;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Fmax {
|
struct Fmax {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
||||||
float fa(a), fb(b);
|
float fa(a), fb(b);
|
||||||
return std::isnan(fb) || fa > fb ? a : b;
|
return Eigen::numext::isnan(fb) || fa > fb ? a : b;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Fmin {
|
struct Fmin {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
||||||
float fa(a), fb(b);
|
float fa(a), fb(b);
|
||||||
return std::isnan(fb) || fa < fb ? a : b;
|
return Eigen::numext::isnan(fb) || fa < fb ? a : b;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1244,7 +1246,8 @@ struct NextAfter {
|
|||||||
float from_as_float(from), to_as_float(to);
|
float from_as_float(from), to_as_float(to);
|
||||||
memcpy(&from_as_int, &from, sizeof(bfloat16));
|
memcpy(&from_as_int, &from, sizeof(bfloat16));
|
||||||
memcpy(&to_as_int, &to, sizeof(bfloat16));
|
memcpy(&to_as_int, &to, sizeof(bfloat16));
|
||||||
if (std::isnan(from_as_float) || std::isnan(to_as_float)) {
|
if (Eigen::numext::isnan(from_as_float) ||
|
||||||
|
Eigen::numext::isnan(to_as_float)) {
|
||||||
return bfloat16(std::numeric_limits<float>::quiet_NaN());
|
return bfloat16(std::numeric_limits<float>::quiet_NaN());
|
||||||
}
|
}
|
||||||
if (from_as_int == to_as_int) {
|
if (from_as_int == to_as_int) {
|
||||||
|
|||||||
@ -2674,7 +2674,9 @@ struct MinMaxFiniteValue<Eigen::half> {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct MinMaxFiniteValue<bfloat16> {
|
struct MinMaxFiniteValue<bfloat16> {
|
||||||
static double max() { return static_cast<double>(bfloat16::highest()); }
|
static double max() {
|
||||||
|
return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
|
||||||
|
}
|
||||||
static double min() { return -max(); }
|
static double min() { return -max(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -103,7 +103,8 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
::testing::Values(
|
::testing::Values(
|
||||||
// The largest negative number smaller than zero in bf16 that's not
|
// The largest negative number smaller than zero in bf16 that's not
|
||||||
// denormalized.
|
// denormalized.
|
||||||
std::make_pair(static_cast<float>(-bfloat16::min_positive_normal()),
|
std::make_pair(static_cast<float>(
|
||||||
|
-std::numeric_limits<Eigen::bfloat16>::min()),
|
||||||
0.0f),
|
0.0f),
|
||||||
// Test odd and even values.
|
// Test odd and even values.
|
||||||
std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f),
|
std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f),
|
||||||
|
|||||||
@ -17,7 +17,6 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_
|
#define TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/numeric_types.h"
|
#include "tensorflow/core/framework/numeric_types.h"
|
||||||
#include "tensorflow/core/platform/byte_order.h"
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
// Compact 16-bit encoding of floating point numbers. This representation uses
|
// Compact 16-bit encoding of floating point numbers. This representation uses
|
||||||
|
|||||||
@ -35,14 +35,16 @@ TEST(Bfloat16Test, FlushDenormalsToZero) {
|
|||||||
for (float denorm = -std::numeric_limits<float>::denorm_min();
|
for (float denorm = -std::numeric_limits<float>::denorm_min();
|
||||||
denorm < std::numeric_limits<float>::denorm_min();
|
denorm < std::numeric_limits<float>::denorm_min();
|
||||||
denorm = std::nextafterf(denorm, 1.0f)) {
|
denorm = std::nextafterf(denorm, 1.0f)) {
|
||||||
bfloat16 bf_trunc = bfloat16::truncate_to_bfloat16(denorm);
|
bfloat16 bf_trunc =
|
||||||
|
bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16(denorm));
|
||||||
ASSERT_EQ(static_cast<float>(bf_trunc), 0.0f);
|
ASSERT_EQ(static_cast<float>(bf_trunc), 0.0f);
|
||||||
if (std::signbit(denorm)) {
|
if (std::signbit(denorm)) {
|
||||||
ASSERT_EQ(bf_trunc.value, 0x8000) << denorm;
|
ASSERT_EQ(bf_trunc.value, 0x8000) << denorm;
|
||||||
} else {
|
} else {
|
||||||
ASSERT_EQ(bf_trunc.value, 0x0000) << denorm;
|
ASSERT_EQ(bf_trunc.value, 0x0000) << denorm;
|
||||||
}
|
}
|
||||||
bfloat16 bf_round = bfloat16::round_to_bfloat16(denorm);
|
bfloat16 bf_round =
|
||||||
|
bfloat16(Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm));
|
||||||
ASSERT_EQ(static_cast<float>(bf_round), 0.0f);
|
ASSERT_EQ(static_cast<float>(bf_round), 0.0f);
|
||||||
if (std::signbit(denorm)) {
|
if (std::signbit(denorm)) {
|
||||||
ASSERT_EQ(bf_round.value, 0x8000) << denorm;
|
ASSERT_EQ(bf_round.value, 0x8000) << denorm;
|
||||||
@ -88,7 +90,8 @@ class Bfloat16Test : public ::testing::Test,
|
|||||||
public ::testing::WithParamInterface<Bfloat16TestParam> {};
|
public ::testing::WithParamInterface<Bfloat16TestParam> {};
|
||||||
|
|
||||||
TEST_P(Bfloat16Test, TruncateTest) {
|
TEST_P(Bfloat16Test, TruncateTest) {
|
||||||
bfloat16 truncated = bfloat16::truncate_to_bfloat16((GetParam().input));
|
bfloat16 truncated =
|
||||||
|
bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16((GetParam().input)));
|
||||||
|
|
||||||
if (std::isnan(GetParam().input)) {
|
if (std::isnan(GetParam().input)) {
|
||||||
EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
|
EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
|
||||||
@ -97,7 +100,8 @@ TEST_P(Bfloat16Test, TruncateTest) {
|
|||||||
|
|
||||||
EXPECT_EQ(GetParam().expected_truncation, float(truncated));
|
EXPECT_EQ(GetParam().expected_truncation, float(truncated));
|
||||||
|
|
||||||
bfloat16 rounded = bfloat16::round_to_bfloat16((GetParam().input));
|
bfloat16 rounded = bfloat16(
|
||||||
|
Eigen::bfloat16_impl::float_to_bfloat16_rtne((GetParam().input)));
|
||||||
if (std::isnan(GetParam().input)) {
|
if (std::isnan(GetParam().input)) {
|
||||||
EXPECT_TRUE(std::isnan(float(rounded)) || std::isinf(float(rounded)));
|
EXPECT_TRUE(std::isnan(float(rounded)) || std::isinf(float(rounded)));
|
||||||
return;
|
return;
|
||||||
@ -172,9 +176,13 @@ TEST(Bfloat16Test, Conversion) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(Bfloat16Test, Epsilon) {
|
TEST(Bfloat16Test, Epsilon) {
|
||||||
EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
|
EXPECT_LT(1.0f,
|
||||||
EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
|
static_cast<float>(Eigen::NumTraits<Eigen::bfloat16>::epsilon() +
|
||||||
bfloat16(1.0f)));
|
bfloat16(1.0f)));
|
||||||
|
EXPECT_EQ(1.0f,
|
||||||
|
static_cast<float>((Eigen::NumTraits<Eigen::bfloat16>::epsilon() /
|
||||||
|
bfloat16(2.0f)) +
|
||||||
|
bfloat16(1.0f)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Bfloat16Test, Negate) {
|
TEST(Bfloat16Test, Negate) {
|
||||||
|
|||||||
@ -43,47 +43,17 @@ typedef Eigen::QUInt16 quint16;
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) {
|
static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) {
|
||||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||||
return *reinterpret_cast<tensorflow::bfloat16*>(
|
return *reinterpret_cast<tensorflow::bfloat16*>(
|
||||||
reinterpret_cast<uint16_t*>(&float_val));
|
reinterpret_cast<uint16_t*>(&float_val));
|
||||||
#else
|
#else
|
||||||
return *reinterpret_cast<tensorflow::bfloat16*>(
|
return *reinterpret_cast<tensorflow::bfloat16*>(
|
||||||
&(reinterpret_cast<uint16_t*>(&float_val)[1]));
|
&(reinterpret_cast<uint16_t*>(&float_val)[1]));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
// TODO(xpan): We probably need to overwrite more methods to have correct eigen
|
|
||||||
// behavior. E.g. epsilon(), dummy_precision, etc. See NumTraits.h in eigen.
|
|
||||||
template <>
|
|
||||||
struct NumTraits<tensorflow::bfloat16>
|
|
||||||
: GenericNumTraits<tensorflow::bfloat16> {
|
|
||||||
enum {
|
|
||||||
IsInteger = 0,
|
|
||||||
IsSigned = 1,
|
|
||||||
RequireInitialization = 0
|
|
||||||
};
|
|
||||||
static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() {
|
|
||||||
return FloatToBFloat16(NumTraits<float>::highest());
|
|
||||||
}
|
|
||||||
|
|
||||||
static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() {
|
|
||||||
return FloatToBFloat16(NumTraits<float>::lowest());
|
|
||||||
}
|
|
||||||
|
|
||||||
static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() {
|
|
||||||
return FloatToBFloat16(NumTraits<float>::infinity());
|
|
||||||
}
|
|
||||||
|
|
||||||
static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() {
|
|
||||||
return FloatToBFloat16(NumTraits<float>::quiet_NaN());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
|
struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
|
||||||
enum {
|
enum {
|
||||||
@ -104,30 +74,6 @@ struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
|
|||||||
static inline tensorflow::tstring quiet_NaN();
|
static inline tensorflow::tstring quiet_NaN();
|
||||||
};
|
};
|
||||||
|
|
||||||
using ::tensorflow::operator==;
|
|
||||||
using ::tensorflow::operator!=;
|
|
||||||
|
|
||||||
namespace numext {
|
|
||||||
|
|
||||||
template <>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 log(
|
|
||||||
const tensorflow::bfloat16& x) {
|
|
||||||
return static_cast<tensorflow::bfloat16>(::logf(static_cast<float>(x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 exp(
|
|
||||||
const tensorflow::bfloat16& x) {
|
|
||||||
return static_cast<tensorflow::bfloat16>(::expf(static_cast<float>(x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs(
|
|
||||||
const tensorflow::bfloat16& x) {
|
|
||||||
return static_cast<tensorflow::bfloat16>(::fabsf(static_cast<float>(x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace numext
|
|
||||||
} // namespace Eigen
|
} // namespace Eigen
|
||||||
|
|
||||||
#if defined(_MSC_VER) && !defined(__clang__)
|
#if defined(_MSC_VER) && !defined(__clang__)
|
||||||
@ -138,6 +84,13 @@ struct hash<Eigen::half> {
|
|||||||
return static_cast<std::size_t>(a.x);
|
return static_cast<std::size_t>(a.x);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct hash<Eigen::bfloat16> {
|
||||||
|
std::size_t operator()(const Eigen::bfloat16& a) const {
|
||||||
|
return hash<float>()(static_cast<float>(a));
|
||||||
|
}
|
||||||
|
};
|
||||||
} // namespace std
|
} // namespace std
|
||||||
#endif // _MSC_VER
|
#endif // _MSC_VER
|
||||||
|
|
||||||
|
|||||||
@ -278,48 +278,6 @@ template <typename From, typename To>
|
|||||||
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
|
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
|
||||||
: functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};
|
: functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};
|
||||||
|
|
||||||
// Specialized cast op impls for bfloat16.
|
|
||||||
template <>
|
|
||||||
struct scalar_cast_op<::tensorflow::bfloat16, float> {
|
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
|
||||||
typedef float result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
|
|
||||||
const ::tensorflow::bfloat16& a) const {
|
|
||||||
float ret;
|
|
||||||
uint16_t* p = reinterpret_cast<uint16_t*>(&ret);
|
|
||||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
|
||||||
p[0] = a.value;
|
|
||||||
p[1] = 0;
|
|
||||||
#else
|
|
||||||
static_assert(::tensorflow::port::kLittleEndian,
|
|
||||||
"Not a little endian system!");
|
|
||||||
p[0] = 0;
|
|
||||||
p[1] = a.value;
|
|
||||||
#endif
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct functor_traits<scalar_cast_op<::tensorflow::bfloat16, float>> {
|
|
||||||
enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct scalar_cast_op<float, ::tensorflow::bfloat16> {
|
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
|
||||||
typedef ::tensorflow::bfloat16 result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()(
|
|
||||||
const float a) const {
|
|
||||||
return ::tensorflow::bfloat16(a);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> {
|
|
||||||
enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace Eigen
|
} // namespace Eigen
|
||||||
|
|
||||||
|
|||||||
@ -165,7 +165,7 @@ bool IsZero(T v);
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
ALWAYS_INLINE bool IsZero(bfloat16 v) {
|
ALWAYS_INLINE bool IsZero(bfloat16 v) {
|
||||||
return v.IsZero();
|
return !static_cast<bool>(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
|||||||
@ -14,9 +14,6 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
// clang-format off
|
|
||||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
|
||||||
// clang-format on
|
|
||||||
#include "tensorflow/core/kernels/training_ops.h"
|
#include "tensorflow/core/kernels/training_ops.h"
|
||||||
|
|
||||||
#include <algorithm> // NOLINT
|
#include <algorithm> // NOLINT
|
||||||
@ -26,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/training_op_helpers.h"
|
#include "tensorflow/core/kernels/training_op_helpers.h"
|
||||||
#include "tensorflow/core/kernels/variable_ops.h"
|
#include "tensorflow/core/kernels/variable_ops.h"
|
||||||
|
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/util/util.h"
|
#include "tensorflow/core/util/util.h"
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,6 @@ package(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "bfloat16",
|
name = "bfloat16",
|
||||||
srcs = ["bfloat16.cc"],
|
|
||||||
hdrs = ["bfloat16.h"],
|
hdrs = ["bfloat16.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core/platform:byte_order",
|
"//tensorflow/core/platform:byte_order",
|
||||||
@ -24,7 +23,6 @@ cc_library(
|
|||||||
filegroup(
|
filegroup(
|
||||||
name = "mobile_srcs_no_runtime",
|
name = "mobile_srcs_no_runtime",
|
||||||
srcs = [
|
srcs = [
|
||||||
"bfloat16.cc",
|
|
||||||
"bfloat16.h",
|
"bfloat16.h",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,28 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
const uint16_t bfloat16::NAN_VALUE;
|
|
||||||
const uint16_t bfloat16::ZERO_VALUE;
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC bfloat16::operator Eigen::half() const {
|
|
||||||
return static_cast<Eigen::half>(float(*this));
|
|
||||||
}
|
|
||||||
} // end namespace tensorflow
|
|
||||||
@ -16,520 +16,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
|
#ifndef TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
|
||||||
#define TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
|
#define TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
|
||||||
|
|
||||||
#include <cmath>
|
// clang-format off
|
||||||
#include <complex>
|
|
||||||
#include <iostream>
|
|
||||||
#include <limits>
|
|
||||||
|
|
||||||
#include "tensorflow/core/platform/byte_order.h"
|
#include "tensorflow/core/platform/byte_order.h"
|
||||||
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
#if defined(__CUDACC__) || (defined(__HIPCC__) && defined(__HIP__))
|
// clang-format on
|
||||||
// All functions callable from CUDA code must be qualified with __device__
|
|
||||||
#define B16_DEVICE_FUNC __host__ __device__
|
|
||||||
|
|
||||||
#else
|
|
||||||
#define B16_DEVICE_FUNC
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace Eigen {
|
|
||||||
struct half;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
typedef Eigen::bfloat16 bfloat16;
|
||||||
// Single precision complex.
|
|
||||||
typedef std::complex<float> complex64;
|
|
||||||
// Double precision complex.
|
|
||||||
typedef std::complex<double> complex128;
|
|
||||||
|
|
||||||
// see framework/bfloat16.h for description.
|
|
||||||
struct bfloat16 {
|
|
||||||
// The default constructor must yield a zero value, not an uninitialized
|
|
||||||
// value; some TF kernels use T() as a zero value.
|
|
||||||
B16_DEVICE_FUNC bfloat16() : value(ZERO_VALUE) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
|
|
||||||
bfloat16 output;
|
|
||||||
if (float_isnan(v)) {
|
|
||||||
output.value = NAN_VALUE;
|
|
||||||
return output;
|
|
||||||
} else if (std::fabs(v) < std::numeric_limits<float>::min()) {
|
|
||||||
// Flush denormal to +/- 0.
|
|
||||||
output.value = std::signbit(v) ? 0x8000 : 0;
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
|
|
||||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
|
||||||
output.value = p[0];
|
|
||||||
#else
|
|
||||||
output.value = p[1];
|
|
||||||
#endif
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const float v) {
|
|
||||||
value = round_to_bfloat16(v).value;
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const double val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
// Following the convention of numpy, converting between complex and
|
|
||||||
// float will lead to loss of imag value.
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const complex64& val)
|
|
||||||
: bfloat16(val.real()) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const complex128& val)
|
|
||||||
: bfloat16(static_cast<float>(val.real())) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const unsigned short val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const unsigned int val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const int val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const long val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const long long val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const T& val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator float() const {
|
|
||||||
float result = 0;
|
|
||||||
|
|
||||||
uint16_t* q = reinterpret_cast<uint16_t*>(&result);
|
|
||||||
|
|
||||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
|
||||||
q[0] = value;
|
|
||||||
#else
|
|
||||||
q[1] = value;
|
|
||||||
#endif
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator bool() const {
|
|
||||||
return static_cast<bool>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator Eigen::half() const;
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator short() const {
|
|
||||||
return static_cast<short>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator int() const {
|
|
||||||
return static_cast<int>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator long() const {
|
|
||||||
return static_cast<long>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator char() const {
|
|
||||||
return static_cast<char>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator signed char() const {
|
|
||||||
return static_cast<signed char>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned char() const {
|
|
||||||
return static_cast<unsigned char>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned short() const {
|
|
||||||
return static_cast<unsigned short>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned int() const {
|
|
||||||
return static_cast<unsigned int>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned long() const {
|
|
||||||
return static_cast<unsigned long>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned long long() const {
|
|
||||||
return static_cast<unsigned long long>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator long long() const {
|
|
||||||
return static_cast<long long>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator double() const {
|
|
||||||
return static_cast<double>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator complex64() const {
|
|
||||||
return complex64(float(*this), float(0.0));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator complex128() const {
|
|
||||||
return complex128(double(*this), double(0.0));
|
|
||||||
}
|
|
||||||
|
|
||||||
union FP32 {
|
|
||||||
unsigned int u;
|
|
||||||
float f;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Converts a float point to bfloat16, with round-nearest-to-even as rounding
|
|
||||||
// method.
|
|
||||||
// TODO: There is a slightly faster implementation (8% faster on CPU)
|
|
||||||
// than this (documented in cl/175987786), that is exponentially harder to
|
|
||||||
// understand and document. Switch to the faster version when converting to
|
|
||||||
// BF16 becomes compute-bound.
|
|
||||||
B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) {
|
|
||||||
uint32_t input;
|
|
||||||
FP32 f;
|
|
||||||
f.f = v;
|
|
||||||
input = f.u;
|
|
||||||
bfloat16 output;
|
|
||||||
|
|
||||||
// Fast rounding algorithm that rounds a half value to nearest even. This
|
|
||||||
// reduces expected error when we convert a large number of floats. Here
|
|
||||||
// is how it works:
|
|
||||||
//
|
|
||||||
// Definitions:
|
|
||||||
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
|
|
||||||
// with the following tags:
|
|
||||||
//
|
|
||||||
// Sign | Exp (8 bits) | Frac (23 bits)
|
|
||||||
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
|
|
||||||
//
|
|
||||||
// S: Sign bit.
|
|
||||||
// E: Exponent bits.
|
|
||||||
// F: First 6 bits of fraction.
|
|
||||||
// L: Least significant bit of resulting bfloat16 if we truncate away the
|
|
||||||
// rest of the float32. This is also the 7th bit of fraction
|
|
||||||
// R: Rounding bit, 8th bit of fraction.
|
|
||||||
// T: Sticky bits, rest of fraction, 15 bits.
|
|
||||||
//
|
|
||||||
// To round half to nearest even, there are 3 cases where we want to round
|
|
||||||
// down (simply truncate the result of the bits away, which consists of
|
|
||||||
// rounding bit and sticky bits) and two cases where we want to round up
|
|
||||||
// (truncate then add one to the result).
|
|
||||||
//
|
|
||||||
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
|
|
||||||
// 1s) as the rounding bias, adds the rounding bias to the input, then
|
|
||||||
// truncates the last 16 bits away.
|
|
||||||
//
|
|
||||||
// To understand how it works, we can analyze this algorithm case by case:
|
|
||||||
//
|
|
||||||
// 1. L = 0, R = 0:
|
|
||||||
// Expect: round down, this is less than half value.
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
|
||||||
// - Adding rounding bias to input may create any carry, depending on
|
|
||||||
// whether there is any value set to 1 in T bits.
|
|
||||||
// - R may be set to 1 if there is a carry.
|
|
||||||
// - L remains 0.
|
|
||||||
// - Note that this case also handles Inf and -Inf, where all fraction
|
|
||||||
// bits, including L, R and Ts are all 0. The output remains Inf after
|
|
||||||
// this algorithm.
|
|
||||||
//
|
|
||||||
// 2. L = 1, R = 0:
|
|
||||||
// Expect: round down, this is less than half value.
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
|
||||||
// - Adding rounding bias to input doesn't change sticky bits but
|
|
||||||
// adds 1 to rounding bit.
|
|
||||||
// - L remains 1.
|
|
||||||
//
|
|
||||||
// 3. L = 0, R = 1, all of T are 0:
|
|
||||||
// Expect: round down, this is exactly at half, the result is already
|
|
||||||
// even (L=0).
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
|
||||||
// - Adding rounding bias to input sets all sticky bits to 1, but
|
|
||||||
// doesn't create a carry.
|
|
||||||
// - R remains 1.
|
|
||||||
// - L remains 0.
|
|
||||||
//
|
|
||||||
// 4. L = 1, R = 1:
|
|
||||||
// Expect: round up, this is exactly at half, the result needs to be
|
|
||||||
// round to the next even number.
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
|
||||||
// - Adding rounding bias to input doesn't change sticky bits, but
|
|
||||||
// creates a carry from rounding bit.
|
|
||||||
// - The carry sets L to 0, creates another carry bit and propagate
|
|
||||||
// forward to F bits.
|
|
||||||
// - If all the F bits are 1, a carry then propagates to the exponent
|
|
||||||
// bits, which then creates the minimum value with the next exponent
|
|
||||||
// value. Note that we won't have the case where exponents are all 1,
|
|
||||||
// since that's either a NaN (handled in the other if condition) or inf
|
|
||||||
// (handled in case 1).
|
|
||||||
//
|
|
||||||
// 5. L = 0, R = 1, any of T is 1:
|
|
||||||
// Expect: round up, this is greater than half.
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
|
||||||
// - Adding rounding bias to input creates a carry from sticky bits,
|
|
||||||
// sets rounding bit to 0, then create another carry.
|
|
||||||
// - The second carry sets L to 1.
|
|
||||||
//
|
|
||||||
// Examples:
|
|
||||||
//
|
|
||||||
// Exact half value that is already even:
|
|
||||||
// Input:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
|
||||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
|
|
||||||
//
|
|
||||||
// This falls into case 3. We truncate the rest of 16 bits and no
|
|
||||||
// carry is created into F and L:
|
|
||||||
//
|
|
||||||
// Output:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L
|
|
||||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
|
||||||
//
|
|
||||||
// Exact half value, round to next even number:
|
|
||||||
// Input:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
|
||||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
|
|
||||||
//
|
|
||||||
// This falls into case 4. We create a carry from R and T,
|
|
||||||
// which then propagates into L and F:
|
|
||||||
//
|
|
||||||
// Output:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L
|
|
||||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// Max denormal value round to min normal value:
|
|
||||||
// Input:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
|
||||||
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
|
|
||||||
//
|
|
||||||
// This falls into case 4. We create a carry from R and T,
|
|
||||||
// propagate into L and F, which then propagates into exponent
|
|
||||||
// bits:
|
|
||||||
//
|
|
||||||
// Output:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L
|
|
||||||
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
|
|
||||||
//
|
|
||||||
// Max normal value round to Inf:
|
|
||||||
// Input:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
|
||||||
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
|
|
||||||
//
|
|
||||||
// This falls into case 4. We create a carry from R and T,
|
|
||||||
// propagate into L and F, which then propagates into exponent
|
|
||||||
// bits:
|
|
||||||
//
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L
|
|
||||||
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// Least significant bit of resulting bfloat.
|
|
||||||
uint32_t lsb = (input >> 16) & 1;
|
|
||||||
uint32_t rounding_bias = 0x7fff + lsb;
|
|
||||||
input += rounding_bias;
|
|
||||||
output.value = static_cast<uint16_t>(input >> 16);
|
|
||||||
if ((f.u & 0xff800000u) == 0) {
|
|
||||||
// Flush positive denormal to 0
|
|
||||||
output.value = 0x0;
|
|
||||||
}
|
|
||||||
if ((f.u & 0xff800000u) == 0x80000000u) {
|
|
||||||
// Flush negative denormal to -0
|
|
||||||
output.value = 0x8000;
|
|
||||||
}
|
|
||||||
if (float_isnan(v)) {
|
|
||||||
output.value = NAN_VALUE;
|
|
||||||
}
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bfloat16 epsilon() {
|
|
||||||
bfloat16 x;
|
|
||||||
x.value = 0x3c00; // 0x1.0p-7
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bfloat16 highest() {
|
|
||||||
bfloat16 x;
|
|
||||||
x.value = 0x7F7F; // 0x1.FEp127
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bfloat16 lowest() {
|
|
||||||
bfloat16 x;
|
|
||||||
x.value = 0xFF7F; // -0x1.FEp127
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bfloat16 min_positive_normal() {
|
|
||||||
bfloat16 x;
|
|
||||||
x.value = 0x0080; // 0x1p-126
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsZero() const { return (value & 0x7FFF) == ZERO_VALUE; }
|
|
||||||
|
|
||||||
uint16_t value;
|
|
||||||
|
|
||||||
// A value that represents "not a number".
|
|
||||||
static constexpr uint16_t NAN_VALUE = 0x7FC0;
|
|
||||||
|
|
||||||
private:
|
|
||||||
// A value that represents "zero".
|
|
||||||
static constexpr uint16_t ZERO_VALUE = 0;
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC static bool float_isnan(const float& x) {
|
|
||||||
#ifdef __CUDA_ARCH__
|
|
||||||
return ::isnan(x);
|
|
||||||
#else
|
|
||||||
return std::isnan(x);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os,
|
|
||||||
const bfloat16& dt) {
|
|
||||||
os << static_cast<float>(dt);
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) {
|
|
||||||
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) - static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) * static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) / static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) {
|
|
||||||
a.value ^= 0x8000;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) < static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) <= static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) == static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) != static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) > static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) >= static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) {
|
|
||||||
a = a + b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) {
|
|
||||||
a = a - b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) {
|
|
||||||
a += bfloat16(1);
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) {
|
|
||||||
a -= bfloat16(1);
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) {
|
|
||||||
bfloat16 original_value = a;
|
|
||||||
++a;
|
|
||||||
return original_value;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) {
|
|
||||||
bfloat16 original_value = a;
|
|
||||||
--a;
|
|
||||||
return original_value;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) {
|
|
||||||
a = a * b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) {
|
|
||||||
a = a / b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
namespace std {
|
|
||||||
template <>
|
|
||||||
struct hash<tensorflow::bfloat16> {
|
|
||||||
size_t operator()(const tensorflow::bfloat16& v) const {
|
|
||||||
return hash<float>()(static_cast<float>(v));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
using tensorflow::bfloat16;
|
|
||||||
inline bool isinf(const bfloat16& a) { return std::isinf(float(a)); }
|
|
||||||
inline bool isnan(const bfloat16& a) { return std::isnan(float(a)); }
|
|
||||||
inline bool isfinite(const bfloat16& a) { return std::isfinite(float(a)); }
|
|
||||||
inline bfloat16 abs(const bfloat16& a) { return bfloat16(std::abs(float(a))); }
|
|
||||||
inline bfloat16 exp(const bfloat16& a) { return bfloat16(std::exp(float(a))); }
|
|
||||||
inline bfloat16 expm1(const bfloat16& a) {
|
|
||||||
return bfloat16(std::expm1(float(a)));
|
|
||||||
}
|
|
||||||
inline bfloat16 log(const bfloat16& a) { return bfloat16(std::log(float(a))); }
|
|
||||||
inline bfloat16 log1p(const bfloat16& a) {
|
|
||||||
return bfloat16(std::log1p(float(a)));
|
|
||||||
}
|
|
||||||
inline bfloat16 log10(const bfloat16& a) {
|
|
||||||
return bfloat16(std::log10(float(a)));
|
|
||||||
}
|
|
||||||
inline bfloat16 sqrt(const bfloat16& a) {
|
|
||||||
return bfloat16(std::sqrt(float(a)));
|
|
||||||
}
|
|
||||||
inline bfloat16 pow(const bfloat16& a, const bfloat16& b) {
|
|
||||||
return bfloat16(std::pow(float(a), float(b)));
|
|
||||||
}
|
|
||||||
inline bfloat16 sin(const bfloat16& a) { return bfloat16(std::sin(float(a))); }
|
|
||||||
inline bfloat16 cos(const bfloat16& a) { return bfloat16(std::cos(float(a))); }
|
|
||||||
inline bfloat16 tan(const bfloat16& a) { return bfloat16(std::tan(float(a))); }
|
|
||||||
inline bfloat16 tanh(const bfloat16& a) {
|
|
||||||
return bfloat16(std::tanh(float(a)));
|
|
||||||
}
|
|
||||||
inline bfloat16 floor(const bfloat16& a) {
|
|
||||||
return bfloat16(std::floor(float(a)));
|
|
||||||
}
|
|
||||||
inline bfloat16 ceil(const bfloat16& a) {
|
|
||||||
return bfloat16(std::ceil(float(a)));
|
|
||||||
}
|
|
||||||
} // namespace std
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
|
#endif // TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
|
||||||
|
|||||||
@ -1446,9 +1446,11 @@ Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto size = (std::is_integral<T>::value
|
auto size = (std::is_integral<T>::value
|
||||||
? ((std::abs(limit - start) + std::abs(delta) - T(1)) /
|
? ((Eigen::numext::abs(limit - start) +
|
||||||
std::abs(delta))
|
Eigen::numext::abs(delta) - T(1)) /
|
||||||
: (std::ceil(std::abs((limit - start) / delta))));
|
Eigen::numext::abs(delta))
|
||||||
|
: (Eigen::numext::ceil(
|
||||||
|
Eigen::numext::abs((limit - start) / delta))));
|
||||||
c->set_output(0, c->Vector(static_cast<int64>(size)));
|
c->set_output(0, c->Vector(static_cast<int64>(size)));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -426,10 +426,10 @@ int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
// NaNs sort to the end.
|
// NaNs sort to the end.
|
||||||
if (!std::isnan(x) && std::isnan(y)) {
|
if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
if (std::isnan(x) && !std::isnan(y)) {
|
if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
20
third_party/eigen3/gpu_packet_math.patch
vendored
20
third_party/eigen3/gpu_packet_math.patch
vendored
@ -22,3 +22,23 @@
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
--- a/Eigen/src/Core/arch/Default/BFloat16.h
|
||||||
|
+++ a/Eigen/src/Core/arch/Default/BFloat16.h
|
||||||
|
@@ -291,7 +291,7 @@
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
|
||||||
|
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||||
|
+#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
|
||||||
|
output.value = p[0];
|
||||||
|
#else
|
||||||
|
output.value = p[1];
|
||||||
|
@@ -493,7 +493,7 @@
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
|
||||||
|
float result = 0;
|
||||||
|
unsigned short* q = reinterpret_cast<unsigned short*>(&result);
|
||||||
|
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||||
|
+#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
|
||||||
|
q[0] = h.value;
|
||||||
|
#else
|
||||||
|
q[1] = h.value;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user