Merge pull request #40962 from Intel-tensorflow:yang/eigen-bf16

Replaces TensorFlow's hand-rolled tensorflow::bfloat16 type with Eigen::bfloat16 throughout the tree.

PiperOrigin-RevId: 321267626
Change-Id: I62c174955a9ce3801158ebfb5aee23a40267c04d
Commit: 25913db8b6

Directories touched:
- tensorflow/compiler/mlir/xla/transforms
- tensorflow/compiler/xla
- tensorflow/core/framework
- tensorflow/core/kernels
- tensorflow/core/lib/bfloat16
- tensorflow/core/ops
- tensorflow/python/lib/core
- third_party/eigen3
@@ -896,7 +896,7 @@ static DenseElementsAttr GetEpsilonValue(Type ty) {
     auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
     return DenseElementsAttr::get(scalar_ty, value);
   } else if (element_ty.isBF16()) {
-    uint16_t raw_epsilon = tensorflow::bfloat16::epsilon().value;
+    uint16_t raw_epsilon = Eigen::NumTraits<Eigen::bfloat16>::epsilon().value;
     auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
     return DenseElementsAttr::get(scalar_ty, value);
   } else if (element_ty.isF32()) {
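The hunk above sets the pattern for the whole change: numeric properties that used to come from static helpers on the hand-rolled tensorflow::bfloat16 are now read from Eigen::NumTraits<Eigen::bfloat16>. A minimal standalone check of the raw bit pattern fed to APInt above (a sketch, not part of the commit; assumes an Eigen recent enough to ship BFloat16.h on the include path):

#include <cassert>
#include <cstdint>
#include <Eigen/Core>

int main() {
  // bfloat16 epsilon is 2^-7, encoded as 0x3C00: the same constant the
  // deleted tensorflow::bfloat16::epsilon() helper hard-coded.
  uint16_t raw = Eigen::NumTraits<Eigen::bfloat16>::epsilon().value;
  assert(raw == 0x3C00);
  assert(static_cast<float>(Eigen::NumTraits<Eigen::bfloat16>::epsilon()) ==
         0.0078125f);
  return 0;
}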
@@ -48,7 +48,9 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
           builder,
           static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
     case BF16:
-      return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
+      return ConstantR0<Eigen::bfloat16>(
+          builder, static_cast<Eigen::bfloat16>(
+                       Eigen::NumTraits<Eigen::bfloat16>::epsilon()));
     case F32:
       return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
     case F64:
@@ -70,7 +72,8 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
       return ConstantR0<Eigen::half>(builder,
                                      Eigen::NumTraits<Eigen::half>::lowest());
     case BF16:
-      return ConstantR0<bfloat16>(builder, bfloat16::lowest());
+      return ConstantR0<Eigen::bfloat16>(
+          builder, Eigen::NumTraits<Eigen::bfloat16>::lowest());
     case F32:
       return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
     case F64:
@@ -86,7 +89,8 @@ XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
       return ConstantR0<Eigen::half>(builder,
                                      std::numeric_limits<Eigen::half>::min());
     case BF16:
-      return ConstantR0<bfloat16>(builder, bfloat16::min_positive_normal());
+      return ConstantR0<Eigen::bfloat16>(
+          builder, std::numeric_limits<Eigen::bfloat16>::min());
     case F32:
       return ConstantR0<float>(builder, std::numeric_limits<float>::min());
     case F64:
@@ -108,7 +112,8 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
       return ConstantR0<Eigen::half>(builder,
                                      Eigen::NumTraits<Eigen::half>::highest());
     case BF16:
-      return ConstantR0<bfloat16>(builder, bfloat16::highest());
+      return ConstantR0<Eigen::bfloat16>(
+          builder, Eigen::NumTraits<Eigen::bfloat16>::highest());
     case F32:
       return ConstantR0<float>(builder, std::numeric_limits<float>::max());
     case F64:
@@ -125,8 +130,8 @@ XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
       return ConstantR0<Eigen::half>(
           builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
     case BF16:
-      return ConstantR0<bfloat16>(
-          builder, bfloat16(std::numeric_limits<float>::quiet_NaN()));
+      return ConstantR0<Eigen::bfloat16>(
+          builder, Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN());
     case F32:
       return ConstantR0<float>(builder,
                                std::numeric_limits<float>::quiet_NaN());
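All five constant helpers above follow the same mapping from the deleted TF statics onto Eigen traits and std::numeric_limits. A standalone sketch of the values involved (illustrative only; printed through float since bfloat16 has no printf format):

#include <cstdio>
#include <limits>
#include <Eigen/Core>

int main() {
  using B = Eigen::bfloat16;
  std::printf("epsilon   %g\n", static_cast<float>(Eigen::NumTraits<B>::epsilon()));    // 0.0078125
  std::printf("lowest    %g\n", static_cast<float>(Eigen::NumTraits<B>::lowest()));     // ~ -3.39e+38
  std::printf("min norm  %g\n", static_cast<float>(std::numeric_limits<B>::min()));     // ~ 1.18e-38
  std::printf("highest   %g\n", static_cast<float>(Eigen::NumTraits<B>::highest()));    // ~ 3.39e+38
  std::printf("quiet NaN %g\n", static_cast<float>(Eigen::NumTraits<B>::quiet_NaN()));  // nan
  return 0;
}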
@@ -218,23 +218,12 @@ int64 RecursiveElementCount(const Shape& shape) {
 // Returns whether the given value is infinity.
 template <typename NativeT>
 bool IsInf(NativeT val) {
-  return std::isinf(val);
-}
-
-template <>
-bool IsInf<half>(half val) {
-  return std::isinf(static_cast<float>(val));
+  return Eigen::numext::isinf(val);
 }
 
 // Returns whether the given value is nan.
 template <typename NativeT>
-float IsNan(NativeT value) {
-  return std::isnan(value);
-}
-
-template <>
-float IsNan(half value) {
-  return IsNan<float>(static_cast<float>(value));
+bool IsNan(NativeT value) {
+  return Eigen::numext::isnan(value);
 }
 
 // Converts the given floating-point value to a string.
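The half specializations can be deleted because Eigen::numext::isinf and Eigen::numext::isnan are overloaded for Eigen::half and Eigen::bfloat16 as well as the built-in floating types (and the incidental float return type of the old IsNan becomes a proper bool). A sketch of the resulting single-template style (standalone, not XLA code):

#include <cstdio>
#include <Eigen/Core>

template <typename NativeT>
bool AnyNan(const NativeT* vals, int n) {
  // One generic predicate, no per-type specializations needed.
  for (int i = 0; i < n; ++i) {
    if (Eigen::numext::isnan(vals[i])) return true;
  }
  return false;
}

int main() {
  Eigen::bfloat16 b[2] = {Eigen::bfloat16(1.0f),
                          Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN()};
  float f[2] = {1.0f, 2.0f};
  std::printf("%d %d\n", AnyNan(b, 2), AnyNan(f, 2));  // prints: 1 0
  return 0;
}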
@@ -455,10 +455,10 @@ int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
     return 1;
   }
   // NaNs sort to the end.
-  if (!std::isnan(x) && std::isnan(y)) {
+  if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) {
     return -1;
   }
-  if (std::isnan(x) && !std::isnan(y)) {
+  if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) {
     return 1;
   }
   return 0;
@@ -962,7 +962,7 @@ struct Frexp {
 struct Heaviside {
   bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
     float x = static_cast<float>(bx);
-    if (std::isnan(x)) {
+    if (Eigen::numext::isnan(x)) {
       return bx;
     }
     if (x < 0) {
@@ -984,7 +984,9 @@ struct IsInf {
   bool operator()(bfloat16 a) { return std::isinf(static_cast<float>(a)); }
 };
 struct IsNan {
-  bool operator()(bfloat16 a) { return std::isnan(static_cast<float>(a)); }
+  bool operator()(bfloat16 a) {
+    return Eigen::numext::isnan(static_cast<float>(a));
+  }
 };
 struct Ldexp {
   bfloat16 operator()(bfloat16 a, int exp) {
@@ -1200,25 +1202,25 @@ struct Ge {
 struct Maximum {
   bfloat16 operator()(bfloat16 a, bfloat16 b) {
     float fa(a), fb(b);
-    return std::isnan(fa) || fa > fb ? a : b;
+    return Eigen::numext::isnan(fa) || fa > fb ? a : b;
   }
 };
 struct Minimum {
   bfloat16 operator()(bfloat16 a, bfloat16 b) {
     float fa(a), fb(b);
-    return std::isnan(fa) || fa < fb ? a : b;
+    return Eigen::numext::isnan(fa) || fa < fb ? a : b;
   }
 };
 struct Fmax {
   bfloat16 operator()(bfloat16 a, bfloat16 b) {
     float fa(a), fb(b);
-    return std::isnan(fb) || fa > fb ? a : b;
+    return Eigen::numext::isnan(fb) || fa > fb ? a : b;
   }
 };
 struct Fmin {
   bfloat16 operator()(bfloat16 a, bfloat16 b) {
     float fa(a), fb(b);
-    return std::isnan(fb) || fa < fb ? a : b;
+    return Eigen::numext::isnan(fb) || fa < fb ? a : b;
  }
 };

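The four functors keep numpy's NaN conventions: Maximum/Minimum test the first operand's NaN so a NaN always propagates, while Fmax/Fmin test the second operand so a lone NaN is ignored. Restated with plain float (a sketch of the semantics only, not the module's code):

#include <cmath>
#include <cstdio>

// maximum: NaN wins. If a is NaN we return it; if b is NaN, "a > b" is false
// and b (the NaN) is returned.
float maximum(float a, float b) { return std::isnan(a) || a > b ? a : b; }
// fmax: NaN loses. If b is NaN we return a; if a is NaN, "a > b" is false and
// the non-NaN b is returned.
float fmax_np(float a, float b) { return std::isnan(b) || a > b ? a : b; }

int main() {
  float nan = std::nanf("");
  std::printf("maximum(nan, 1) = %g\n", maximum(nan, 1.0f));  // nan
  std::printf("fmax_np(nan, 1) = %g\n", fmax_np(nan, 1.0f));  // 1
  return 0;
}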
@@ -1244,7 +1246,8 @@ struct NextAfter {
     float from_as_float(from), to_as_float(to);
     memcpy(&from_as_int, &from, sizeof(bfloat16));
     memcpy(&to_as_int, &to, sizeof(bfloat16));
-    if (std::isnan(from_as_float) || std::isnan(to_as_float)) {
+    if (Eigen::numext::isnan(from_as_float) ||
+        Eigen::numext::isnan(to_as_float)) {
       return bfloat16(std::numeric_limits<float>::quiet_NaN());
     }
     if (from_as_int == to_as_int) {
@@ -2674,7 +2674,9 @@ struct MinMaxFiniteValue<Eigen::half> {

 template <>
 struct MinMaxFiniteValue<bfloat16> {
-  static double max() { return static_cast<double>(bfloat16::highest()); }
+  static double max() {
+    return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
+  }
   static double min() { return -max(); }
 };

@@ -103,7 +103,8 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(
         // The largest negative number smaller than zero in bf16 that's not
         // denormalized.
-        std::make_pair(static_cast<float>(-bfloat16::min_positive_normal()),
+        std::make_pair(static_cast<float>(
+                           -std::numeric_limits<Eigen::bfloat16>::min()),
                        0.0f),
         // Test odd and even values.
         std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f),
@@ -17,7 +17,6 @@ limitations under the License.
 #define TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_

 #include "tensorflow/core/framework/numeric_types.h"
-#include "tensorflow/core/platform/byte_order.h"
 #include "tensorflow/core/platform/types.h"

 // Compact 16-bit encoding of floating point numbers. This representation uses
@@ -35,14 +35,16 @@ TEST(Bfloat16Test, FlushDenormalsToZero) {
   for (float denorm = -std::numeric_limits<float>::denorm_min();
        denorm < std::numeric_limits<float>::denorm_min();
        denorm = std::nextafterf(denorm, 1.0f)) {
-    bfloat16 bf_trunc = bfloat16::truncate_to_bfloat16(denorm);
+    bfloat16 bf_trunc =
+        bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16(denorm));
     ASSERT_EQ(static_cast<float>(bf_trunc), 0.0f);
     if (std::signbit(denorm)) {
       ASSERT_EQ(bf_trunc.value, 0x8000) << denorm;
     } else {
       ASSERT_EQ(bf_trunc.value, 0x0000) << denorm;
     }
-    bfloat16 bf_round = bfloat16::round_to_bfloat16(denorm);
+    bfloat16 bf_round =
+        bfloat16(Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm));
     ASSERT_EQ(static_cast<float>(bf_round), 0.0f);
     if (std::signbit(denorm)) {
       ASSERT_EQ(bf_round.value, 0x8000) << denorm;
@@ -88,7 +90,8 @@ class Bfloat16Test : public ::testing::Test,
                      public ::testing::WithParamInterface<Bfloat16TestParam> {};

 TEST_P(Bfloat16Test, TruncateTest) {
-  bfloat16 truncated = bfloat16::truncate_to_bfloat16((GetParam().input));
+  bfloat16 truncated =
+      bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16((GetParam().input)));

   if (std::isnan(GetParam().input)) {
     EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
@@ -97,7 +100,8 @@ TEST_P(Bfloat16Test, TruncateTest) {

   EXPECT_EQ(GetParam().expected_truncation, float(truncated));

-  bfloat16 rounded = bfloat16::round_to_bfloat16((GetParam().input));
+  bfloat16 rounded = bfloat16(
+      Eigen::bfloat16_impl::float_to_bfloat16_rtne((GetParam().input)));
   if (std::isnan(GetParam().input)) {
     EXPECT_TRUE(std::isnan(float(rounded)) || std::isinf(float(rounded)));
     return;
@@ -172,9 +176,13 @@ TEST(Bfloat16Test, Conversion) {
 }

 TEST(Bfloat16Test, Epsilon) {
-  EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
-  EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
-                                     bfloat16(1.0f)));
+  EXPECT_LT(1.0f,
+            static_cast<float>(Eigen::NumTraits<Eigen::bfloat16>::epsilon() +
+                               bfloat16(1.0f)));
+  EXPECT_EQ(1.0f,
+            static_cast<float>((Eigen::NumTraits<Eigen::bfloat16>::epsilon() /
+                                bfloat16(2.0f)) +
+                               bfloat16(1.0f)));
 }

 TEST(Bfloat16Test, Negate) {
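The tests now reach the two conversion paths through Eigen's implementation namespace. The difference they exercise, in a standalone sketch (function names as used in the tests above; newer Eigen releases add a template parameter to float_to_bfloat16_rtne, so this matches the tree as of this commit):

#include <cstdio>
#include <Eigen/Core>

int main() {
  // 1.00390625f sits exactly halfway between bf16 1.0 and bf16 1.0078125.
  float halfway = 1.00390625f;
  // Truncation just drops the low 16 bits of the float representation.
  Eigen::bfloat16 t(Eigen::bfloat16_impl::truncate_to_bfloat16(halfway));
  // Round-to-nearest-even keeps the even neighbor, here 1.0.
  Eigen::bfloat16 r(Eigen::bfloat16_impl::float_to_bfloat16_rtne(halfway));
  std::printf("truncate %g, rtne %g\n", static_cast<float>(t),
              static_cast<float>(r));  // truncate 1, rtne 1
  // Just past the halfway point, RTNE rounds up instead.
  Eigen::bfloat16 r2(Eigen::bfloat16_impl::float_to_bfloat16_rtne(1.005859375f));
  std::printf("rtne %g\n", static_cast<float>(r2));  // 1.00781
  return 0;
}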
@@ -43,47 +43,17 @@ typedef Eigen::QUInt16 quint16;

 }  // namespace tensorflow

 static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) {
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
   return *reinterpret_cast<tensorflow::bfloat16*>(
       reinterpret_cast<uint16_t*>(&float_val));
 #else
   return *reinterpret_cast<tensorflow::bfloat16*>(
       &(reinterpret_cast<uint16_t*>(&float_val)[1]));
 #endif
 }

 namespace Eigen {
-// TODO(xpan): We probably need to overwrite more methods to have correct eigen
-// behavior. E.g. epsilon(), dummy_precision, etc. See NumTraits.h in eigen.
-template <>
-struct NumTraits<tensorflow::bfloat16>
-    : GenericNumTraits<tensorflow::bfloat16> {
-  enum {
-    IsInteger = 0,
-    IsSigned = 1,
-    RequireInitialization = 0
-  };
-  static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() {
-    return FloatToBFloat16(NumTraits<float>::highest());
-  }
-
-  static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() {
-    return FloatToBFloat16(NumTraits<float>::lowest());
-  }
-
-  static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() {
-    return FloatToBFloat16(NumTraits<float>::infinity());
-  }
-
-  static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() {
-    return FloatToBFloat16(NumTraits<float>::quiet_NaN());
-  }
-};
-
 template <>
 struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
   enum {
@@ -104,30 +74,6 @@ struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
   static inline tensorflow::tstring quiet_NaN();
 };

-using ::tensorflow::operator==;
-using ::tensorflow::operator!=;
-
-namespace numext {
-
-template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 log(
-    const tensorflow::bfloat16& x) {
-  return static_cast<tensorflow::bfloat16>(::logf(static_cast<float>(x)));
-}
-
-template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 exp(
-    const tensorflow::bfloat16& x) {
-  return static_cast<tensorflow::bfloat16>(::expf(static_cast<float>(x)));
-}
-
-template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs(
-    const tensorflow::bfloat16& x) {
-  return static_cast<tensorflow::bfloat16>(::fabsf(static_cast<float>(x)));
-}
-
-}  // namespace numext
 }  // namespace Eigen

 #if defined(_MSC_VER) && !defined(__clang__)
@@ -138,6 +84,13 @@ struct hash<Eigen::half> {
     return static_cast<std::size_t>(a.x);
   }
 };
+
+template <>
+struct hash<Eigen::bfloat16> {
+  std::size_t operator()(const Eigen::bfloat16& a) const {
+    return hash<float>()(static_cast<float>(a));
+  }
+};
 }  // namespace std
 #endif  // _MSC_VER

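The hand-written numext::log/exp/abs overloads can go because, once tensorflow::bfloat16 is Eigen::bfloat16, Eigen's BFloat16.h supplies the same math through its usual overload set. A sketch (assuming those overloads resolve as in stock Eigen):

#include <cstdio>
#include <Eigen/Core>

int main() {
  Eigen::bfloat16 x(2.0f);
  // Each routes through float internally, just like the deleted TF overloads.
  std::printf("log(2)  ~ %g\n", static_cast<float>(Eigen::numext::log(x)));  // ~0.69
  std::printf("exp(2)  ~ %g\n", static_cast<float>(Eigen::numext::exp(x)));  // ~7.4
  std::printf("abs(-2) = %g\n",
              static_cast<float>(Eigen::numext::abs(Eigen::bfloat16(-2.0f))));  // 2
  return 0;
}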
@@ -278,48 +278,6 @@ template <typename From, typename To>
 struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
     : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};

-// Specialized cast op impls for bfloat16.
-template <>
-struct scalar_cast_op<::tensorflow::bfloat16, float> {
-  EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
-  typedef float result_type;
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
-      const ::tensorflow::bfloat16& a) const {
-    float ret;
-    uint16_t* p = reinterpret_cast<uint16_t*>(&ret);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-    p[0] = a.value;
-    p[1] = 0;
-#else
-    static_assert(::tensorflow::port::kLittleEndian,
-                  "Not a little endian system!");
-    p[0] = 0;
-    p[1] = a.value;
-#endif
-    return ret;
-  }
-};
-
-template <>
-struct functor_traits<scalar_cast_op<::tensorflow::bfloat16, float>> {
-  enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
-};
-
-template <>
-struct scalar_cast_op<float, ::tensorflow::bfloat16> {
-  EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
-  typedef ::tensorflow::bfloat16 result_type;
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()(
-      const float a) const {
-    return ::tensorflow::bfloat16(a);
-  }
-};
-
-template <>
-struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> {
-  enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
-};
-
 }  // namespace internal
 }  // namespace Eigen

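The scalar_cast_op specializations are dead weight for the same reason: Eigen can already cast between float and its own bfloat16, so .cast<>() works out of the box. A round-trip sketch (standalone, not kernel code):

#include <cstdio>
#include <Eigen/Core>

int main() {
  Eigen::Array3f f(1.0f, 2.5f, 3.14159f);
  // float -> bfloat16 -> float; the middle step quantizes to 8 bits of
  // significand, so pi comes back as 3.140625.
  Eigen::Array<Eigen::bfloat16, 3, 1> b = f.cast<Eigen::bfloat16>();
  Eigen::Array3f back = b.cast<float>();
  std::printf("%g %g %g\n", back[0], back[1], back[2]);  // 1 2.5 3.14062
  return 0;
}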
@@ -165,7 +165,7 @@ bool IsZero(T v);

 template <>
 ALWAYS_INLINE bool IsZero(bfloat16 v) {
-  return v.IsZero();
+  return !static_cast<bool>(v);
 }

 template <>
@@ -14,9 +14,6 @@ limitations under the License.
 ==============================================================================*/

 #define EIGEN_USE_THREADS
-// clang-format off
-#include "tensorflow/core/lib/bfloat16/bfloat16.h"
-// clang-format on
 #include "tensorflow/core/kernels/training_ops.h"

 #include <algorithm>  // NOLINT
@@ -26,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/training_op_helpers.h"
 #include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/lib/bfloat16/bfloat16.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/util.h"

@@ -12,7 +12,6 @@ package(

 cc_library(
     name = "bfloat16",
-    srcs = ["bfloat16.cc"],
     hdrs = ["bfloat16.h"],
     deps = [
         "//tensorflow/core/platform:byte_order",
@@ -24,7 +23,6 @@ cc_library(
 filegroup(
     name = "mobile_srcs_no_runtime",
     srcs = [
-        "bfloat16.cc",
         "bfloat16.h",
     ],
 )
@@ -1,28 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/bfloat16/bfloat16.h"
-
-#include "third_party/eigen3/Eigen/Core"
-
-namespace tensorflow {
-
-const uint16_t bfloat16::NAN_VALUE;
-const uint16_t bfloat16::ZERO_VALUE;
-
-B16_DEVICE_FUNC bfloat16::operator Eigen::half() const {
-  return static_cast<Eigen::half>(float(*this));
-}
-}  // end namespace tensorflow
@@ -16,520 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
 #define TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_

-#include <cmath>
-#include <complex>
-#include <iostream>
-#include <limits>
-
 // clang-format off
 #include "tensorflow/core/platform/byte_order.h"
-
-#if defined(__CUDACC__) || (defined(__HIPCC__) && defined(__HIP__))
-// All functions callable from CUDA code must be qualified with __device__
-#define B16_DEVICE_FUNC __host__ __device__
-
-#else
-#define B16_DEVICE_FUNC
-
-#endif
-
-namespace Eigen {
-struct half;
-}
 #include "third_party/eigen3/Eigen/Core"
 // clang-format on

 namespace tensorflow {
-
-// Single precision complex.
-typedef std::complex<float> complex64;
-// Double precision complex.
-typedef std::complex<double> complex128;
-
-// see framework/bfloat16.h for description.
-struct bfloat16 {
-  // The default constructor must yield a zero value, not an uninitialized
-  // value; some TF kernels use T() as a zero value.
-  B16_DEVICE_FUNC bfloat16() : value(ZERO_VALUE) {}
-
-  B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
-    bfloat16 output;
-    if (float_isnan(v)) {
-      output.value = NAN_VALUE;
-      return output;
-    } else if (std::fabs(v) < std::numeric_limits<float>::min()) {
-      // Flush denormal to +/- 0.
-      output.value = std::signbit(v) ? 0x8000 : 0;
-      return output;
-    }
-    const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-    output.value = p[0];
-#else
-    output.value = p[1];
-#endif
-    return output;
-  }
-
-  B16_DEVICE_FUNC explicit bfloat16(const float v) {
-    value = round_to_bfloat16(v).value;
-  }
-
-  B16_DEVICE_FUNC explicit bfloat16(const double val)
-      : bfloat16(static_cast<float>(val)) {}
-  // Following the convention of numpy, converting between complex and
-  // float will lead to loss of imag value.
-  B16_DEVICE_FUNC explicit bfloat16(const complex64& val)
-      : bfloat16(val.real()) {}
-
-  B16_DEVICE_FUNC explicit bfloat16(const complex128& val)
-      : bfloat16(static_cast<float>(val.real())) {}
-
-  B16_DEVICE_FUNC explicit bfloat16(const unsigned short val)
-      : bfloat16(static_cast<float>(val)) {}
-
-  B16_DEVICE_FUNC explicit bfloat16(const unsigned int val)
-      : bfloat16(static_cast<float>(val)) {}
-
-  B16_DEVICE_FUNC explicit bfloat16(const int val)
-      : bfloat16(static_cast<float>(val)) {}
-
-  B16_DEVICE_FUNC explicit bfloat16(const long val)
-      : bfloat16(static_cast<float>(val)) {}
-
-  B16_DEVICE_FUNC explicit bfloat16(const long long val)
-      : bfloat16(static_cast<float>(val)) {}
-
-  template <class T>
-  B16_DEVICE_FUNC explicit bfloat16(const T& val)
-      : bfloat16(static_cast<float>(val)) {}
-
-  B16_DEVICE_FUNC explicit operator float() const {
-    float result = 0;
-
-    uint16_t* q = reinterpret_cast<uint16_t*>(&result);
-
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-    q[0] = value;
-#else
-    q[1] = value;
-#endif
-    return result;
-  }
-
-  B16_DEVICE_FUNC explicit operator bool() const {
-    return static_cast<bool>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator Eigen::half() const;
-
-  B16_DEVICE_FUNC explicit operator short() const {
-    return static_cast<short>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator int() const {
-    return static_cast<int>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator long() const {
-    return static_cast<long>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator char() const {
-    return static_cast<char>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator signed char() const {
-    return static_cast<signed char>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator unsigned char() const {
-    return static_cast<unsigned char>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator unsigned short() const {
-    return static_cast<unsigned short>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator unsigned int() const {
-    return static_cast<unsigned int>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator unsigned long() const {
-    return static_cast<unsigned long>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator unsigned long long() const {
-    return static_cast<unsigned long long>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator long long() const {
-    return static_cast<long long>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator double() const {
-    return static_cast<double>(float(*this));
-  }
-
-  B16_DEVICE_FUNC explicit operator complex64() const {
-    return complex64(float(*this), float(0.0));
-  }
-
-  B16_DEVICE_FUNC explicit operator complex128() const {
-    return complex128(double(*this), double(0.0));
-  }
-
-  union FP32 {
-    unsigned int u;
-    float f;
-  };
-
-  // Converts a float point to bfloat16, with round-nearest-to-even as rounding
-  // method.
-  // TODO: There is a slightly faster implementation (8% faster on CPU)
-  // than this (documented in cl/175987786), that is exponentially harder to
-  // understand and document. Switch to the faster version when converting to
-  // BF16 becomes compute-bound.
-  B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) {
-    uint32_t input;
-    FP32 f;
-    f.f = v;
-    input = f.u;
-    bfloat16 output;
-
-    // Fast rounding algorithm that rounds a half value to nearest even. This
-    // reduces expected error when we convert a large number of floats. Here
-    // is how it works:
-    //
-    // Definitions:
-    // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
-    // with the following tags:
-    //
-    // Sign |  Exp (8 bits) | Frac (23 bits)
-    //  S     EEEEEEEE         FFFFFFLRTTTTTTTTTTTTTTT
-    //
-    //  S: Sign bit.
-    //  E: Exponent bits.
-    //  F: First 6 bits of fraction.
-    //  L: Least significant bit of resulting bfloat16 if we truncate away the
-    //  rest of the float32. This is also the 7th bit of fraction
-    //  R: Rounding bit, 8th bit of fraction.
-    //  T: Sticky bits, rest of fraction, 15 bits.
-    //
-    // To round half to nearest even, there are 3 cases where we want to round
-    // down (simply truncate the result of the bits away, which consists of
-    // rounding bit and sticky bits) and two cases where we want to round up
-    // (truncate then add one to the result).
-    //
-    // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
-    // 1s) as the rounding bias, adds the rounding bias to the input, then
-    // truncates the last 16 bits away.
-    //
-    // To understand how it works, we can analyze this algorithm case by case:
-    //
-    // 1. L = 0, R = 0:
-    //   Expect: round down, this is less than half value.
-    //
-    //   Algorithm:
-    //   - Rounding bias: 0x7fff + 0 = 0x7fff
-    //   - Adding rounding bias to input may create any carry, depending on
-    //   whether there is any value set to 1 in T bits.
-    //   - R may be set to 1 if there is a carry.
-    //   - L remains 0.
-    //   - Note that this case also handles Inf and -Inf, where all fraction
-    //   bits, including L, R and Ts are all 0. The output remains Inf after
-    //   this algorithm.
-    //
-    // 2. L = 1, R = 0:
-    //   Expect: round down, this is less than half value.
-    //
-    //   Algorithm:
-    //   - Rounding bias: 0x7fff + 1 = 0x8000
-    //   - Adding rounding bias to input doesn't change sticky bits but
-    //   adds 1 to rounding bit.
-    //   - L remains 1.
-    //
-    // 3. L = 0, R = 1, all of T are 0:
-    //   Expect: round down, this is exactly at half, the result is already
-    //   even (L=0).
-    //
-    //   Algorithm:
-    //   - Rounding bias: 0x7fff + 0 = 0x7fff
-    //   - Adding rounding bias to input sets all sticky bits to 1, but
-    //   doesn't create a carry.
-    //   - R remains 1.
-    //   - L remains 0.
-    //
-    // 4. L = 1, R = 1:
-    //   Expect: round up, this is exactly at half, the result needs to be
-    //   round to the next even number.
-    //
-    //   Algorithm:
-    //   - Rounding bias: 0x7fff + 1 = 0x8000
-    //   - Adding rounding bias to input doesn't change sticky bits, but
-    //   creates a carry from rounding bit.
-    //   - The carry sets L to 0, creates another carry bit and propagate
-    //   forward to F bits.
-    //   - If all the F bits are 1, a carry then propagates to the exponent
-    //   bits, which then creates the minimum value with the next exponent
-    //   value. Note that we won't have the case where exponents are all 1,
-    //   since that's either a NaN (handled in the other if condition) or inf
-    //   (handled in case 1).
-    //
-    // 5. L = 0, R = 1, any of T is 1:
-    //   Expect: round up, this is greater than half.
-    //
-    //   Algorithm:
-    //   - Rounding bias: 0x7fff + 0 = 0x7fff
-    //   - Adding rounding bias to input creates a carry from sticky bits,
-    //   sets rounding bit to 0, then create another carry.
-    //   - The second carry sets L to 1.
-    //
-    // Examples:
-    //
-    //  Exact half value that is already even:
-    //    Input:
-    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
-    //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
-    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0     1000000000000000
-    //
-    //     This falls into case 3. We truncate the rest of 16 bits and no
-    //     carry is created into F and L:
-    //
-    //    Output:
-    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
-    //     S     E E E E E E E E      F F F F F F L
-    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
-    //
-    //  Exact half value, round to next even number:
-    //    Input:
-    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
-    //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
-    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 0 1     1000000000000000
-    //
-    //     This falls into case 4. We create a carry from R and T,
-    //     which then propagates into L and F:
-    //
-    //    Output:
-    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
-    //     S     E E E E E E E E      F F F F F F L
-    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
-    //
-    //
-    //  Max denormal value round to min normal value:
-    //    Input:
-    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
-    //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
-    //     0     0 0 0 0 0 0 0 0      1 1 1 1 1 1 1     1111111111111111
-    //
-    //     This falls into case 4. We create a carry from R and T,
-    //     propagate into L and F, which then propagates into exponent
-    //     bits:
-    //
-    //    Output:
-    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
-    //     S     E E E E E E E E      F F F F F F L
-    //     0     0 0 0 0 0 0 0 1      0 0 0 0 0 0 0
-    //
-    //  Max normal value round to Inf:
-    //    Input:
-    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
-    //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
-    //     0     1 1 1 1 1 1 1 0      1 1 1 1 1 1 1     1111111111111111
-    //
-    //     This falls into case 4. We create a carry from R and T,
-    //     propagate into L and F, which then propagates into exponent
-    //     bits:
-    //
-    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
-    //     S     E E E E E E E E      F F F F F F L
-    //     0     1 1 1 1 1 1 1 1      0 0 0 0 0 0 0
-    //
-    //
-    // Least significant bit of resulting bfloat.
-    uint32_t lsb = (input >> 16) & 1;
-    uint32_t rounding_bias = 0x7fff + lsb;
-    input += rounding_bias;
-    output.value = static_cast<uint16_t>(input >> 16);
-    if ((f.u & 0xff800000u) == 0) {
-      // Flush positive denormal to 0
-      output.value = 0x0;
-    }
-    if ((f.u & 0xff800000u) == 0x80000000u) {
-      // Flush negative denormal to -0
-      output.value = 0x8000;
-    }
-    if (float_isnan(v)) {
-      output.value = NAN_VALUE;
-    }
-    return output;
-  }
-
-  static bfloat16 epsilon() {
-    bfloat16 x;
-    x.value = 0x3c00;  // 0x1.0p-7
-    return x;
-  }
-
-  static bfloat16 highest() {
-    bfloat16 x;
-    x.value = 0x7F7F;  // 0x1.FEp127
-    return x;
-  }
-
-  static bfloat16 lowest() {
-    bfloat16 x;
-    x.value = 0xFF7F;  // -0x1.FEp127
-    return x;
-  }
-
-  static bfloat16 min_positive_normal() {
-    bfloat16 x;
-    x.value = 0x0080;  // 0x1p-126
-    return x;
-  }
-
-  bool IsZero() const { return (value & 0x7FFF) == ZERO_VALUE; }
-
-  uint16_t value;
-
-  // A value that represents "not a number".
-  static constexpr uint16_t NAN_VALUE = 0x7FC0;
-
- private:
-  // A value that represents "zero".
-  static constexpr uint16_t ZERO_VALUE = 0;
-
-  B16_DEVICE_FUNC static bool float_isnan(const float& x) {
-#ifdef __CUDA_ARCH__
-    return ::isnan(x);
-#else
-    return std::isnan(x);
-#endif
-  }
-};
-
-B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os,
-                                                const bfloat16& dt) {
-  os << static_cast<float>(dt);
-  return os;
-}
-
-B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
-  return bfloat16(static_cast<float>(a) + static_cast<float>(b));
-}
-B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) {
-  return bfloat16(static_cast<float>(a) + static_cast<float>(b));
-}
-B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) {
-  return bfloat16(static_cast<float>(a) + static_cast<float>(b));
-}
-B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
-  return bfloat16(static_cast<float>(a) - static_cast<float>(b));
-}
-B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
-  return bfloat16(static_cast<float>(a) * static_cast<float>(b));
-}
-B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
-  return bfloat16(static_cast<float>(a) / static_cast<float>(b));
-}
-B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) {
-  a.value ^= 0x8000;
-  return a;
-}
-B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) {
-  return static_cast<float>(a) < static_cast<float>(b);
-}
-B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) {
-  return static_cast<float>(a) <= static_cast<float>(b);
-}
-B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) {
-  return static_cast<float>(a) == static_cast<float>(b);
-}
-B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) {
-  return static_cast<float>(a) != static_cast<float>(b);
-}
-B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) {
-  return static_cast<float>(a) > static_cast<float>(b);
-}
-B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) {
-  return static_cast<float>(a) >= static_cast<float>(b);
-}
-B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) {
-  a = a + b;
-  return a;
-}
-B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) {
-  a = a - b;
-  return a;
-}
-B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) {
-  a += bfloat16(1);
-  return a;
-}
-B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) {
-  a -= bfloat16(1);
-  return a;
-}
-B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) {
-  bfloat16 original_value = a;
-  ++a;
-  return original_value;
-}
-B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) {
-  bfloat16 original_value = a;
-  --a;
-  return original_value;
-}
-B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) {
-  a = a * b;
-  return a;
-}
-B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) {
-  a = a / b;
-  return a;
-}
+typedef Eigen::bfloat16 bfloat16;
 }  // end namespace tensorflow

-namespace std {
-template <>
-struct hash<tensorflow::bfloat16> {
-  size_t operator()(const tensorflow::bfloat16& v) const {
-    return hash<float>()(static_cast<float>(v));
-  }
-};
-
-using tensorflow::bfloat16;
-inline bool isinf(const bfloat16& a) { return std::isinf(float(a)); }
-inline bool isnan(const bfloat16& a) { return std::isnan(float(a)); }
-inline bool isfinite(const bfloat16& a) { return std::isfinite(float(a)); }
-inline bfloat16 abs(const bfloat16& a) { return bfloat16(std::abs(float(a))); }
-inline bfloat16 exp(const bfloat16& a) { return bfloat16(std::exp(float(a))); }
-inline bfloat16 expm1(const bfloat16& a) {
-  return bfloat16(std::expm1(float(a)));
-}
-inline bfloat16 log(const bfloat16& a) { return bfloat16(std::log(float(a))); }
-inline bfloat16 log1p(const bfloat16& a) {
-  return bfloat16(std::log1p(float(a)));
-}
-inline bfloat16 log10(const bfloat16& a) {
-  return bfloat16(std::log10(float(a)));
-}
-inline bfloat16 sqrt(const bfloat16& a) {
-  return bfloat16(std::sqrt(float(a)));
-}
-inline bfloat16 pow(const bfloat16& a, const bfloat16& b) {
-  return bfloat16(std::pow(float(a), float(b)));
-}
-inline bfloat16 sin(const bfloat16& a) { return bfloat16(std::sin(float(a))); }
-inline bfloat16 cos(const bfloat16& a) { return bfloat16(std::cos(float(a))); }
-inline bfloat16 tan(const bfloat16& a) { return bfloat16(std::tan(float(a))); }
-inline bfloat16 tanh(const bfloat16& a) {
-  return bfloat16(std::tanh(float(a)));
-}
-inline bfloat16 floor(const bfloat16& a) {
-  return bfloat16(std::floor(float(a)));
-}
-inline bfloat16 ceil(const bfloat16& a) {
-  return bfloat16(std::ceil(float(a)));
-}
-}  // namespace std

 #endif  // TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
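The long comment in the deleted round_to_bfloat16 still describes the trick Eigen's float_to_bfloat16_rtne uses: add lsb + 0x7fff to the float's bit pattern and truncate, letting the carry implement round-to-nearest-even. A self-contained sketch of just that fast path (no NaN or denormal handling, unlike the full code above):

#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t RoundToBfloat16(float v) {
  uint32_t input;
  std::memcpy(&input, &v, sizeof(input));     // bit pattern of the float
  uint32_t lsb = (input >> 16) & 1;           // L: lsb of the bf16 result
  uint32_t rounding_bias = 0x7fff + lsb;      // 15 ones, plus L
  input += rounding_bias;                     // a carry implements RNE
  return static_cast<uint16_t>(input >> 16);  // keep the top 16 bits
}

int main() {
  // Exactly halfway with an even low bit: stays at 1.0 (0x3f80), case 3 above.
  std::printf("%#06x\n", RoundToBfloat16(1.00390625f));   // 0x3f80
  // Above the halfway point: rounds up to 1.0078125 (0x3f81), case 5 above.
  std::printf("%#06x\n", RoundToBfloat16(1.005859375f));  // 0x3f81
  return 0;
}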
@@ -1446,9 +1446,11 @@ Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
   }

   auto size = (std::is_integral<T>::value
-                   ? ((std::abs(limit - start) + std::abs(delta) - T(1)) /
-                      std::abs(delta))
-                   : (std::ceil(std::abs((limit - start) / delta))));
+                   ? ((Eigen::numext::abs(limit - start) +
+                       Eigen::numext::abs(delta) - T(1)) /
+                      Eigen::numext::abs(delta))
+                   : (Eigen::numext::ceil(
+                         Eigen::numext::abs((limit - start) / delta))));
   c->set_output(0, c->Vector(static_cast<int64>(size)));
   return Status::OK();
 }
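Eigen::numext::abs and ::ceil make this shape function instantiable for bfloat16 and half as well as the standard types. The two branches, restated for plain types (a sketch of the formula only):

#include <cmath>
#include <cstdio>
#include <cstdlib>

// Integral ranges: ceil-division, (|limit - start| + |delta| - 1) / |delta|.
long long IntRangeSize(long long start, long long limit, long long delta) {
  return (std::llabs(limit - start) + std::llabs(delta) - 1) /
         std::llabs(delta);
}

// Floating ranges: ceil(|(limit - start) / delta|).
long long FloatRangeSize(float start, float limit, float delta) {
  return static_cast<long long>(std::ceil(std::fabs((limit - start) / delta)));
}

int main() {
  std::printf("%lld\n", IntRangeSize(0, 10, 3));          // 4: {0, 3, 6, 9}
  std::printf("%lld\n", FloatRangeSize(0.f, 1.f, 0.3f));  // 4: {0, .3, .6, .9}
  return 0;
}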
@@ -426,10 +426,10 @@ int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
     return 1;
   }
   // NaNs sort to the end.
-  if (!std::isnan(x) && std::isnan(y)) {
+  if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) {
     return -1;
   }
-  if (std::isnan(x) && !std::isnan(y)) {
+  if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) {
     return 1;
   }
   return 0;
third_party/eigen3/gpu_packet_math.patch (vendored, 20 additions)
@@ -22,3 +22,23 @@
     return res;
   }
 };
+--- a/Eigen/src/Core/arch/Default/BFloat16.h
++++ a/Eigen/src/Core/arch/Default/BFloat16.h
+@@ -291,7 +291,7 @@
+     return output;
+   }
+   const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
+-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
++#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+     output.value = p[0];
+ #else
+     output.value = p[1];
+@@ -493,7 +493,7 @@
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
+   float result = 0;
+   unsigned short* q = reinterpret_cast<unsigned short*>(&result);
+-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
++#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+   q[0] = h.value;
+ #else
+   q[1] = h.value;
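The vendored patch matters because in a preprocessor #if an undefined identifier evaluates to 0, so on compilers that do not predefine __BYTE_ORDER__ (MSVC, notably) the unguarded test reads 0 == 0 and wrongly takes the big-endian branch. A minimal illustration of the guarded form (a sketch; as in the patch, the fallback branch must be the little-endian one):

#include <cstdio>

#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
static const char* kByteOrder = "big-endian";
#else
static const char* kByteOrder =
    "little-endian (also the fallback when the macro is missing)";
#endif

int main() {
  std::printf("%s\n", kByteOrder);
  return 0;
}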