Replace tensorflow::bfloat16 with Eigen::bfloat16
This commit is contained in:
parent
aed5e3ea00
commit
3451f21c1e
@ -28,20 +28,20 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||||
#include "mlir/IR/Types.h" // from @llvm-project
|
#include "mlir/IR/Types.h" // from @llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
|
||||||
@ -897,7 +897,7 @@ static DenseElementsAttr GetEpsilonValue(Type ty) {
|
|||||||
auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
|
auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
|
||||||
return DenseElementsAttr::get(scalar_ty, value);
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
} else if (element_ty.isBF16()) {
|
} else if (element_ty.isBF16()) {
|
||||||
uint16_t raw_epsilon = tensorflow::bfloat16::epsilon().value;
|
uint16_t raw_epsilon = Eigen::NumTraits<Eigen::bfloat16>::epsilon().value;
|
||||||
auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
|
auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
|
||||||
return DenseElementsAttr::get(scalar_ty, value);
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
} else if (element_ty.isF32()) {
|
} else if (element_ty.isF32()) {
|
||||||
|
@ -48,7 +48,9 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
builder,
|
builder,
|
||||||
static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
|
static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
|
builder, static_cast<Eigen::bfloat16>(
|
||||||
|
Eigen::NumTraits<Eigen::bfloat16>::epsilon()));
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
|
return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
|
||||||
case F64:
|
case F64:
|
||||||
@ -70,7 +72,8 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
return ConstantR0<Eigen::half>(builder,
|
return ConstantR0<Eigen::half>(builder,
|
||||||
Eigen::NumTraits<Eigen::half>::lowest());
|
Eigen::NumTraits<Eigen::half>::lowest());
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(builder, bfloat16::lowest());
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
|
builder, Eigen::NumTraits<Eigen::bfloat16>::lowest());
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
|
return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
|
||||||
case F64:
|
case F64:
|
||||||
@ -86,7 +89,8 @@ XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
return ConstantR0<Eigen::half>(builder,
|
return ConstantR0<Eigen::half>(builder,
|
||||||
std::numeric_limits<Eigen::half>::min());
|
std::numeric_limits<Eigen::half>::min());
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(builder, bfloat16::min_positive_normal());
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
|
builder, std::numeric_limits<Eigen::bfloat16>::min());
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder, std::numeric_limits<float>::min());
|
return ConstantR0<float>(builder, std::numeric_limits<float>::min());
|
||||||
case F64:
|
case F64:
|
||||||
@ -108,7 +112,8 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
return ConstantR0<Eigen::half>(builder,
|
return ConstantR0<Eigen::half>(builder,
|
||||||
Eigen::NumTraits<Eigen::half>::highest());
|
Eigen::NumTraits<Eigen::half>::highest());
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(builder, bfloat16::highest());
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
|
builder, Eigen::NumTraits<Eigen::bfloat16>::highest());
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder, std::numeric_limits<float>::max());
|
return ConstantR0<float>(builder, std::numeric_limits<float>::max());
|
||||||
case F64:
|
case F64:
|
||||||
@ -125,8 +130,8 @@ XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
|
|||||||
return ConstantR0<Eigen::half>(
|
return ConstantR0<Eigen::half>(
|
||||||
builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
|
builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
|
||||||
case BF16:
|
case BF16:
|
||||||
return ConstantR0<bfloat16>(
|
return ConstantR0<Eigen::bfloat16>(
|
||||||
builder, bfloat16(std::numeric_limits<float>::quiet_NaN()));
|
builder, Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN());
|
||||||
case F32:
|
case F32:
|
||||||
return ConstantR0<float>(builder,
|
return ConstantR0<float>(builder,
|
||||||
std::numeric_limits<float>::quiet_NaN());
|
std::numeric_limits<float>::quiet_NaN());
|
||||||
|
@ -2678,7 +2678,9 @@ struct MinMaxFiniteValue<Eigen::half> {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct MinMaxFiniteValue<bfloat16> {
|
struct MinMaxFiniteValue<bfloat16> {
|
||||||
static double max() { return static_cast<double>(bfloat16::highest()); }
|
static double max() {
|
||||||
|
return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
|
||||||
|
}
|
||||||
static double min() { return -max(); }
|
static double min() { return -max(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -103,7 +103,8 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
::testing::Values(
|
::testing::Values(
|
||||||
// The largest negative number smaller than zero in bf16 that's not
|
// The largest negative number smaller than zero in bf16 that's not
|
||||||
// denormalized.
|
// denormalized.
|
||||||
std::make_pair(static_cast<float>(-bfloat16::min_positive_normal()),
|
std::make_pair(static_cast<float>(
|
||||||
|
-std::numeric_limits<Eigen::bfloat16>::min()),
|
||||||
0.0f),
|
0.0f),
|
||||||
// Test odd and even values.
|
// Test odd and even values.
|
||||||
std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f),
|
std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f),
|
||||||
|
@ -35,14 +35,16 @@ TEST(Bfloat16Test, FlushDenormalsToZero) {
|
|||||||
for (float denorm = -std::numeric_limits<float>::denorm_min();
|
for (float denorm = -std::numeric_limits<float>::denorm_min();
|
||||||
denorm < std::numeric_limits<float>::denorm_min();
|
denorm < std::numeric_limits<float>::denorm_min();
|
||||||
denorm = std::nextafterf(denorm, 1.0f)) {
|
denorm = std::nextafterf(denorm, 1.0f)) {
|
||||||
bfloat16 bf_trunc = bfloat16::truncate_to_bfloat16(denorm);
|
bfloat16 bf_trunc =
|
||||||
|
bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16(denorm));
|
||||||
ASSERT_EQ(static_cast<float>(bf_trunc), 0.0f);
|
ASSERT_EQ(static_cast<float>(bf_trunc), 0.0f);
|
||||||
if (std::signbit(denorm)) {
|
if (std::signbit(denorm)) {
|
||||||
ASSERT_EQ(bf_trunc.value, 0x8000) << denorm;
|
ASSERT_EQ(bf_trunc.value, 0x8000) << denorm;
|
||||||
} else {
|
} else {
|
||||||
ASSERT_EQ(bf_trunc.value, 0x0000) << denorm;
|
ASSERT_EQ(bf_trunc.value, 0x0000) << denorm;
|
||||||
}
|
}
|
||||||
bfloat16 bf_round = bfloat16::round_to_bfloat16(denorm);
|
bfloat16 bf_round =
|
||||||
|
bfloat16(Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm));
|
||||||
ASSERT_EQ(static_cast<float>(bf_round), 0.0f);
|
ASSERT_EQ(static_cast<float>(bf_round), 0.0f);
|
||||||
if (std::signbit(denorm)) {
|
if (std::signbit(denorm)) {
|
||||||
ASSERT_EQ(bf_round.value, 0x8000) << denorm;
|
ASSERT_EQ(bf_round.value, 0x8000) << denorm;
|
||||||
@ -88,7 +90,8 @@ class Bfloat16Test : public ::testing::Test,
|
|||||||
public ::testing::WithParamInterface<Bfloat16TestParam> {};
|
public ::testing::WithParamInterface<Bfloat16TestParam> {};
|
||||||
|
|
||||||
TEST_P(Bfloat16Test, TruncateTest) {
|
TEST_P(Bfloat16Test, TruncateTest) {
|
||||||
bfloat16 truncated = bfloat16::truncate_to_bfloat16((GetParam().input));
|
bfloat16 truncated =
|
||||||
|
bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16((GetParam().input)));
|
||||||
|
|
||||||
if (std::isnan(GetParam().input)) {
|
if (std::isnan(GetParam().input)) {
|
||||||
EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
|
EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
|
||||||
@ -97,7 +100,8 @@ TEST_P(Bfloat16Test, TruncateTest) {
|
|||||||
|
|
||||||
EXPECT_EQ(GetParam().expected_truncation, float(truncated));
|
EXPECT_EQ(GetParam().expected_truncation, float(truncated));
|
||||||
|
|
||||||
bfloat16 rounded = bfloat16::round_to_bfloat16((GetParam().input));
|
bfloat16 rounded = bfloat16(
|
||||||
|
Eigen::bfloat16_impl::float_to_bfloat16_rtne((GetParam().input)));
|
||||||
if (std::isnan(GetParam().input)) {
|
if (std::isnan(GetParam().input)) {
|
||||||
EXPECT_TRUE(std::isnan(float(rounded)) || std::isinf(float(rounded)));
|
EXPECT_TRUE(std::isnan(float(rounded)) || std::isinf(float(rounded)));
|
||||||
return;
|
return;
|
||||||
@ -172,9 +176,13 @@ TEST(Bfloat16Test, Conversion) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(Bfloat16Test, Epsilon) {
|
TEST(Bfloat16Test, Epsilon) {
|
||||||
EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
|
EXPECT_LT(1.0f,
|
||||||
EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
|
static_cast<float>(Eigen::NumTraits<Eigen::bfloat16>::epsilon() +
|
||||||
bfloat16(1.0f)));
|
bfloat16(1.0f)));
|
||||||
|
EXPECT_EQ(1.0f,
|
||||||
|
static_cast<float>((Eigen::NumTraits<Eigen::bfloat16>::epsilon() /
|
||||||
|
bfloat16(2.0f)) +
|
||||||
|
bfloat16(1.0f)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Bfloat16Test, Negate) {
|
TEST(Bfloat16Test, Negate) {
|
||||||
|
@ -43,47 +43,17 @@ typedef Eigen::QUInt16 quint16;
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) {
|
static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) {
|
||||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||||
return *reinterpret_cast<tensorflow::bfloat16*>(
|
return *reinterpret_cast<tensorflow::bfloat16*>(
|
||||||
reinterpret_cast<uint16_t*>(&float_val));
|
reinterpret_cast<uint16_t*>(&float_val));
|
||||||
#else
|
#else
|
||||||
return *reinterpret_cast<tensorflow::bfloat16*>(
|
return *reinterpret_cast<tensorflow::bfloat16*>(
|
||||||
&(reinterpret_cast<uint16_t*>(&float_val)[1]));
|
&(reinterpret_cast<uint16_t*>(&float_val)[1]));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
// TODO(xpan): We probably need to overwrite more methods to have correct eigen
|
|
||||||
// behavior. E.g. epsilon(), dummy_precision, etc. See NumTraits.h in eigen.
|
|
||||||
template <>
|
|
||||||
struct NumTraits<tensorflow::bfloat16>
|
|
||||||
: GenericNumTraits<tensorflow::bfloat16> {
|
|
||||||
enum {
|
|
||||||
IsInteger = 0,
|
|
||||||
IsSigned = 1,
|
|
||||||
RequireInitialization = 0
|
|
||||||
};
|
|
||||||
static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() {
|
|
||||||
return FloatToBFloat16(NumTraits<float>::highest());
|
|
||||||
}
|
|
||||||
|
|
||||||
static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() {
|
|
||||||
return FloatToBFloat16(NumTraits<float>::lowest());
|
|
||||||
}
|
|
||||||
|
|
||||||
static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() {
|
|
||||||
return FloatToBFloat16(NumTraits<float>::infinity());
|
|
||||||
}
|
|
||||||
|
|
||||||
static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() {
|
|
||||||
return FloatToBFloat16(NumTraits<float>::quiet_NaN());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
|
struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
|
||||||
enum {
|
enum {
|
||||||
@ -104,30 +74,6 @@ struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
|
|||||||
static inline tensorflow::tstring quiet_NaN();
|
static inline tensorflow::tstring quiet_NaN();
|
||||||
};
|
};
|
||||||
|
|
||||||
using ::tensorflow::operator==;
|
|
||||||
using ::tensorflow::operator!=;
|
|
||||||
|
|
||||||
namespace numext {
|
|
||||||
|
|
||||||
template <>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 log(
|
|
||||||
const tensorflow::bfloat16& x) {
|
|
||||||
return static_cast<tensorflow::bfloat16>(::logf(static_cast<float>(x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 exp(
|
|
||||||
const tensorflow::bfloat16& x) {
|
|
||||||
return static_cast<tensorflow::bfloat16>(::expf(static_cast<float>(x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs(
|
|
||||||
const tensorflow::bfloat16& x) {
|
|
||||||
return static_cast<tensorflow::bfloat16>(::fabsf(static_cast<float>(x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace numext
|
|
||||||
} // namespace Eigen
|
} // namespace Eigen
|
||||||
|
|
||||||
#if defined(_MSC_VER) && !defined(__clang__)
|
#if defined(_MSC_VER) && !defined(__clang__)
|
||||||
|
@ -16,13 +16,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
|
||||||
#include "tensorflow/core/framework/bfloat16.h"
|
#include "tensorflow/core/framework/bfloat16.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/platform/byte_order.h"
|
#include "tensorflow/core/platform/byte_order.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
// Note that the GPU cast functor templates need to be instantiated unlike the
|
// Note that the GPU cast functor templates need to be instantiated unlike the
|
||||||
// CPU ones, and hence their specializations are different than that for CPUs.
|
// CPU ones, and hence their specializations are different than that for CPUs.
|
||||||
@ -278,48 +278,6 @@ template <typename From, typename To>
|
|||||||
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
|
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
|
||||||
: functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};
|
: functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};
|
||||||
|
|
||||||
// Specialized cast op impls for bfloat16.
|
|
||||||
template <>
|
|
||||||
struct scalar_cast_op<::tensorflow::bfloat16, float> {
|
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
|
||||||
typedef float result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
|
|
||||||
const ::tensorflow::bfloat16& a) const {
|
|
||||||
float ret;
|
|
||||||
uint16_t* p = reinterpret_cast<uint16_t*>(&ret);
|
|
||||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
|
||||||
p[0] = a.value;
|
|
||||||
p[1] = 0;
|
|
||||||
#else
|
|
||||||
static_assert(::tensorflow::port::kLittleEndian,
|
|
||||||
"Not a little endian system!");
|
|
||||||
p[0] = 0;
|
|
||||||
p[1] = a.value;
|
|
||||||
#endif
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct functor_traits<scalar_cast_op<::tensorflow::bfloat16, float>> {
|
|
||||||
enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct scalar_cast_op<float, ::tensorflow::bfloat16> {
|
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
|
||||||
typedef ::tensorflow::bfloat16 result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()(
|
|
||||||
const float a) const {
|
|
||||||
return ::tensorflow::bfloat16(a);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> {
|
|
||||||
enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace Eigen
|
} // namespace Eigen
|
||||||
|
|
||||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/framework/bfloat16.h"
|
#include "tensorflow/core/framework/bfloat16.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
@ -37,6 +36,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#ifdef TENSORFLOW_USE_LIBXSMM
|
#ifdef TENSORFLOW_USE_LIBXSMM
|
||||||
#include "include/libxsmm_intrinsics_x86.h"
|
#include "include/libxsmm_intrinsics_x86.h"
|
||||||
#include "include/libxsmm_malloc.h"
|
#include "include/libxsmm_malloc.h"
|
||||||
@ -165,7 +165,7 @@ bool IsZero(T v);
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
ALWAYS_INLINE bool IsZero(bfloat16 v) {
|
ALWAYS_INLINE bool IsZero(bfloat16 v) {
|
||||||
return v.IsZero();
|
return !static_cast<bool>(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -17,12 +17,4 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {} // end namespace tensorflow
|
||||||
|
|
||||||
const uint16_t bfloat16::NAN_VALUE;
|
|
||||||
const uint16_t bfloat16::ZERO_VALUE;
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC bfloat16::operator Eigen::half() const {
|
|
||||||
return static_cast<Eigen::half>(float(*this));
|
|
||||||
}
|
|
||||||
} // end namespace tensorflow
|
|
||||||
|
@ -22,479 +22,18 @@ limitations under the License.
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include "tensorflow/core/platform/byte_order.h"
|
#include "tensorflow/core/platform/byte_order.h"
|
||||||
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
#if defined(__CUDACC__) || (defined(__HIPCC__) && defined(__HIP__))
|
|
||||||
// All functions callable from CUDA code must be qualified with __device__
|
|
||||||
#define B16_DEVICE_FUNC __host__ __device__
|
|
||||||
|
|
||||||
#else
|
|
||||||
#define B16_DEVICE_FUNC
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
struct bfloat16;
|
||||||
struct half;
|
struct half;
|
||||||
}
|
} // namespace Eigen
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
typedef Eigen::bfloat16 bfloat16;
|
||||||
// Single precision complex.
|
|
||||||
typedef std::complex<float> complex64;
|
|
||||||
// Double precision complex.
|
|
||||||
typedef std::complex<double> complex128;
|
|
||||||
|
|
||||||
// see framework/bfloat16.h for description.
|
|
||||||
struct bfloat16 {
|
|
||||||
// The default constructor must yield a zero value, not an uninitialized
|
|
||||||
// value; some TF kernels use T() as a zero value.
|
|
||||||
B16_DEVICE_FUNC bfloat16() : value(ZERO_VALUE) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
|
|
||||||
bfloat16 output;
|
|
||||||
if (float_isnan(v)) {
|
|
||||||
output.value = NAN_VALUE;
|
|
||||||
return output;
|
|
||||||
} else if (std::fabs(v) < std::numeric_limits<float>::min()) {
|
|
||||||
// Flush denormal to +/- 0.
|
|
||||||
output.value = std::signbit(v) ? 0x8000 : 0;
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
|
|
||||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
|
||||||
output.value = p[0];
|
|
||||||
#else
|
|
||||||
output.value = p[1];
|
|
||||||
#endif
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const float v) {
|
|
||||||
value = round_to_bfloat16(v).value;
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const double val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
// Following the convention of numpy, converting between complex and
|
|
||||||
// float will lead to loss of imag value.
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const complex64& val)
|
|
||||||
: bfloat16(val.real()) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const complex128& val)
|
|
||||||
: bfloat16(static_cast<float>(val.real())) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const unsigned short val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const unsigned int val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const int val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const long val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const long long val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
B16_DEVICE_FUNC explicit bfloat16(const T& val)
|
|
||||||
: bfloat16(static_cast<float>(val)) {}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator float() const {
|
|
||||||
float result = 0;
|
|
||||||
|
|
||||||
uint16_t* q = reinterpret_cast<uint16_t*>(&result);
|
|
||||||
|
|
||||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
|
||||||
q[0] = value;
|
|
||||||
#else
|
|
||||||
q[1] = value;
|
|
||||||
#endif
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator bool() const {
|
|
||||||
return static_cast<bool>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator Eigen::half() const;
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator short() const {
|
|
||||||
return static_cast<short>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator int() const {
|
|
||||||
return static_cast<int>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator long() const {
|
|
||||||
return static_cast<long>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator char() const {
|
|
||||||
return static_cast<char>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator signed char() const {
|
|
||||||
return static_cast<signed char>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned char() const {
|
|
||||||
return static_cast<unsigned char>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned short() const {
|
|
||||||
return static_cast<unsigned short>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned int() const {
|
|
||||||
return static_cast<unsigned int>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned long() const {
|
|
||||||
return static_cast<unsigned long>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator unsigned long long() const {
|
|
||||||
return static_cast<unsigned long long>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator long long() const {
|
|
||||||
return static_cast<long long>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator double() const {
|
|
||||||
return static_cast<double>(float(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator complex64() const {
|
|
||||||
return complex64(float(*this), float(0.0));
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC explicit operator complex128() const {
|
|
||||||
return complex128(double(*this), double(0.0));
|
|
||||||
}
|
|
||||||
|
|
||||||
union FP32 {
|
|
||||||
unsigned int u;
|
|
||||||
float f;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Converts a float point to bfloat16, with round-nearest-to-even as rounding
|
|
||||||
// method.
|
|
||||||
// TODO: There is a slightly faster implementation (8% faster on CPU)
|
|
||||||
// than this (documented in cl/175987786), that is exponentially harder to
|
|
||||||
// understand and document. Switch to the faster version when converting to
|
|
||||||
// BF16 becomes compute-bound.
|
|
||||||
B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) {
|
|
||||||
uint32_t input;
|
|
||||||
FP32 f;
|
|
||||||
f.f = v;
|
|
||||||
input = f.u;
|
|
||||||
bfloat16 output;
|
|
||||||
|
|
||||||
// Fast rounding algorithm that rounds a half value to nearest even. This
|
|
||||||
// reduces expected error when we convert a large number of floats. Here
|
|
||||||
// is how it works:
|
|
||||||
//
|
|
||||||
// Definitions:
|
|
||||||
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
|
|
||||||
// with the following tags:
|
|
||||||
//
|
|
||||||
// Sign | Exp (8 bits) | Frac (23 bits)
|
|
||||||
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
|
|
||||||
//
|
|
||||||
// S: Sign bit.
|
|
||||||
// E: Exponent bits.
|
|
||||||
// F: First 6 bits of fraction.
|
|
||||||
// L: Least significant bit of resulting bfloat16 if we truncate away the
|
|
||||||
// rest of the float32. This is also the 7th bit of fraction
|
|
||||||
// R: Rounding bit, 8th bit of fraction.
|
|
||||||
// T: Sticky bits, rest of fraction, 15 bits.
|
|
||||||
//
|
|
||||||
// To round half to nearest even, there are 3 cases where we want to round
|
|
||||||
// down (simply truncate the result of the bits away, which consists of
|
|
||||||
// rounding bit and sticky bits) and two cases where we want to round up
|
|
||||||
// (truncate then add one to the result).
|
|
||||||
//
|
|
||||||
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
|
|
||||||
// 1s) as the rounding bias, adds the rounding bias to the input, then
|
|
||||||
// truncates the last 16 bits away.
|
|
||||||
//
|
|
||||||
// To understand how it works, we can analyze this algorithm case by case:
|
|
||||||
//
|
|
||||||
// 1. L = 0, R = 0:
|
|
||||||
// Expect: round down, this is less than half value.
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
|
||||||
// - Adding rounding bias to input may create any carry, depending on
|
|
||||||
// whether there is any value set to 1 in T bits.
|
|
||||||
// - R may be set to 1 if there is a carry.
|
|
||||||
// - L remains 0.
|
|
||||||
// - Note that this case also handles Inf and -Inf, where all fraction
|
|
||||||
// bits, including L, R and Ts are all 0. The output remains Inf after
|
|
||||||
// this algorithm.
|
|
||||||
//
|
|
||||||
// 2. L = 1, R = 0:
|
|
||||||
// Expect: round down, this is less than half value.
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
|
||||||
// - Adding rounding bias to input doesn't change sticky bits but
|
|
||||||
// adds 1 to rounding bit.
|
|
||||||
// - L remains 1.
|
|
||||||
//
|
|
||||||
// 3. L = 0, R = 1, all of T are 0:
|
|
||||||
// Expect: round down, this is exactly at half, the result is already
|
|
||||||
// even (L=0).
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
|
||||||
// - Adding rounding bias to input sets all sticky bits to 1, but
|
|
||||||
// doesn't create a carry.
|
|
||||||
// - R remains 1.
|
|
||||||
// - L remains 0.
|
|
||||||
//
|
|
||||||
// 4. L = 1, R = 1:
|
|
||||||
// Expect: round up, this is exactly at half, the result needs to be
|
|
||||||
// round to the next even number.
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
|
||||||
// - Adding rounding bias to input doesn't change sticky bits, but
|
|
||||||
// creates a carry from rounding bit.
|
|
||||||
// - The carry sets L to 0, creates another carry bit and propagate
|
|
||||||
// forward to F bits.
|
|
||||||
// - If all the F bits are 1, a carry then propagates to the exponent
|
|
||||||
// bits, which then creates the minimum value with the next exponent
|
|
||||||
// value. Note that we won't have the case where exponents are all 1,
|
|
||||||
// since that's either a NaN (handled in the other if condition) or inf
|
|
||||||
// (handled in case 1).
|
|
||||||
//
|
|
||||||
// 5. L = 0, R = 1, any of T is 1:
|
|
||||||
// Expect: round up, this is greater than half.
|
|
||||||
//
|
|
||||||
// Algorithm:
|
|
||||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
|
||||||
// - Adding rounding bias to input creates a carry from sticky bits,
|
|
||||||
// sets rounding bit to 0, then create another carry.
|
|
||||||
// - The second carry sets L to 1.
|
|
||||||
//
|
|
||||||
// Examples:
|
|
||||||
//
|
|
||||||
// Exact half value that is already even:
|
|
||||||
// Input:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
|
||||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
|
|
||||||
//
|
|
||||||
// This falls into case 3. We truncate the rest of 16 bits and no
|
|
||||||
// carry is created into F and L:
|
|
||||||
//
|
|
||||||
// Output:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L
|
|
||||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
|
||||||
//
|
|
||||||
// Exact half value, round to next even number:
|
|
||||||
// Input:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
|
||||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
|
|
||||||
//
|
|
||||||
// This falls into case 4. We create a carry from R and T,
|
|
||||||
// which then propagates into L and F:
|
|
||||||
//
|
|
||||||
// Output:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L
|
|
||||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// Max denormal value round to min normal value:
|
|
||||||
// Input:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
|
||||||
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
|
|
||||||
//
|
|
||||||
// This falls into case 4. We create a carry from R and T,
|
|
||||||
// propagate into L and F, which then propagates into exponent
|
|
||||||
// bits:
|
|
||||||
//
|
|
||||||
// Output:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L
|
|
||||||
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
|
|
||||||
//
|
|
||||||
// Max normal value round to Inf:
|
|
||||||
// Input:
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
|
||||||
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
|
|
||||||
//
|
|
||||||
// This falls into case 4. We create a carry from R and T,
|
|
||||||
// propagate into L and F, which then propagates into exponent
|
|
||||||
// bits:
|
|
||||||
//
|
|
||||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
|
||||||
// S E E E E E E E E F F F F F F L
|
|
||||||
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// Least significant bit of resulting bfloat.
|
|
||||||
uint32_t lsb = (input >> 16) & 1;
|
|
||||||
uint32_t rounding_bias = 0x7fff + lsb;
|
|
||||||
input += rounding_bias;
|
|
||||||
output.value = static_cast<uint16_t>(input >> 16);
|
|
||||||
if ((f.u & 0xff800000u) == 0) {
|
|
||||||
// Flush positive denormal to 0
|
|
||||||
output.value = 0x0;
|
|
||||||
}
|
|
||||||
if ((f.u & 0xff800000u) == 0x80000000u) {
|
|
||||||
// Flush negative denormal to -0
|
|
||||||
output.value = 0x8000;
|
|
||||||
}
|
|
||||||
if (float_isnan(v)) {
|
|
||||||
output.value = NAN_VALUE;
|
|
||||||
}
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bfloat16 epsilon() {
|
|
||||||
bfloat16 x;
|
|
||||||
x.value = 0x3c00; // 0x1.0p-7
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bfloat16 highest() {
|
|
||||||
bfloat16 x;
|
|
||||||
x.value = 0x7F7F; // 0x1.FEp127
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bfloat16 lowest() {
|
|
||||||
bfloat16 x;
|
|
||||||
x.value = 0xFF7F; // -0x1.FEp127
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bfloat16 min_positive_normal() {
|
|
||||||
bfloat16 x;
|
|
||||||
x.value = 0x0080; // 0x1p-126
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsZero() const { return (value & 0x7FFF) == ZERO_VALUE; }
|
|
||||||
|
|
||||||
uint16_t value;
|
|
||||||
|
|
||||||
// A value that represents "not a number".
|
|
||||||
static constexpr uint16_t NAN_VALUE = 0x7FC0;
|
|
||||||
|
|
||||||
private:
|
|
||||||
// A value that represents "zero".
|
|
||||||
static constexpr uint16_t ZERO_VALUE = 0;
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC static bool float_isnan(const float& x) {
|
|
||||||
#ifdef __CUDA_ARCH__
|
|
||||||
return ::isnan(x);
|
|
||||||
#else
|
|
||||||
return std::isnan(x);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os,
|
|
||||||
const bfloat16& dt) {
|
|
||||||
os << static_cast<float>(dt);
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) {
|
|
||||||
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) - static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) * static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
|
|
||||||
return bfloat16(static_cast<float>(a) / static_cast<float>(b));
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) {
|
|
||||||
a.value ^= 0x8000;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) < static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) <= static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) == static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) != static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) > static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) {
|
|
||||||
return static_cast<float>(a) >= static_cast<float>(b);
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) {
|
|
||||||
a = a + b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) {
|
|
||||||
a = a - b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) {
|
|
||||||
a += bfloat16(1);
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) {
|
|
||||||
a -= bfloat16(1);
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) {
|
|
||||||
bfloat16 original_value = a;
|
|
||||||
++a;
|
|
||||||
return original_value;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) {
|
|
||||||
bfloat16 original_value = a;
|
|
||||||
--a;
|
|
||||||
return original_value;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) {
|
|
||||||
a = a * b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) {
|
|
||||||
a = a / b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
namespace std {
|
namespace std {
|
||||||
template <>
|
|
||||||
struct hash<tensorflow::bfloat16> {
|
|
||||||
size_t operator()(const tensorflow::bfloat16& v) const {
|
|
||||||
return hash<float>()(static_cast<float>(v));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
using tensorflow::bfloat16;
|
using tensorflow::bfloat16;
|
||||||
inline bool isinf(const bfloat16& a) { return std::isinf(float(a)); }
|
inline bool isinf(const bfloat16& a) { return std::isinf(float(a)); }
|
||||||
|
Loading…
x
Reference in New Issue
Block a user