diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
index c932469c56a..6c015c84bab 100644
--- a/tensorflow/compiler/xla/primitive_util.cc
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -15,8 +15,11 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/primitive_util.h"
 
+#include <limits>
+
 #include "absl/strings/ascii.h"
 #include "absl/strings/numbers.h"
+#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/logging.h"
@@ -31,9 +34,42 @@ int SignificandWidth(PrimitiveType type) {
     case F64:
       return std::numeric_limits<double>::digits;
     case BF16:
-      return kBFloat16MantissaBits + 1;
+      return std::numeric_limits<bfloat16>::digits;
     case F16:
-      return 11;
+      return std::numeric_limits<half>::digits;
+    default:
+      LOG(FATAL) << "Not a floating data type " << type;
+  }
+}
+
+int ExponentWidth(PrimitiveType type) {
+  // Per the IEEE-754 standard: a floating point type is stored as a sign bit,
+  // a biased exponent and a trailing significand field.
+  int total_bit_width = BitWidth(type);
+  // This field contains all bits in the significand other than the leading
+  // digit which is implied by the exponent.
+  int trailing_significand_field_width = SignificandWidth(type) - 1;
+  // The sign is encoded with a single bit.
+  int kSignBitWidth = 1;
+  // The remaining bits are used for encoding the biased exponent.
+  return total_bit_width - (trailing_significand_field_width + kSignBitWidth);
+}
+
+int OverflowExponent(PrimitiveType type) {
+  // |std::numeric_limits<float>::max_exponent| is defined as: "Maximum
+  // positive integer such that radix raised to the power one less than that
+  // integer is a representable finite floating-point number." As such it does
+  // not actually yield the maximum exponent but the exponent of the first
+  // integer which overflows.
+  switch (type) {
+    case F32:
+      return std::numeric_limits<float>::max_exponent;
+    case F64:
+      return std::numeric_limits<double>::max_exponent;
+    case BF16:
+      return std::numeric_limits<bfloat16>::max_exponent;
+    case F16:
+      return std::numeric_limits<half>::max_exponent;
     default:
       LOG(FATAL) << "Not a floating data type " << type;
   }
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index 1228b4f9a32..0e3bdfdd4d0 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -33,12 +33,13 @@ namespace primitive_util {
 // For non-float datatypes, results in a LOG(FATAL).
 int SignificandWidth(PrimitiveType type);
 
-// The number of exponent bits in a BF16 value.
-const int kBFloat16ExponentBits = 8;
+// Returns the count of exponent bits for float datatypes.
+// For non-float datatypes, results in a LOG(FATAL).
+int ExponentWidth(PrimitiveType type);
 
-// The number of mantissa bits in a BF16 value. There is an implicit leading
-// 1, so there is an implicit additional bit of precision.
-const int kBFloat16MantissaBits = 7;
+// Returns the exponent of the smallest number which cannot be represented.
+// For non-float datatypes, results in a LOG(FATAL).
+int OverflowExponent(PrimitiveType type);
 
 // Returns the XLA primitive type (eg, F32) corresponding to the given
 // template parameter native type (eg, float).
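As a sanity check of the arithmetic used by ExponentWidth and OverflowExponent, here is a minimal standalone sketch (plain C++11, no XLA dependencies; the bit widths and significand digits are the usual IEEE-754/bfloat16 values, assumed here rather than taken from primitive_util):

// Minimal sketch: exponent width = total bits - stored significand bits - sign
// bit, and max_exponent marks the first power of two that overflows.
#include <cassert>
#include <limits>

int main() {
  // BF16: 16 bits total, 8 significand digits (7 stored + 1 implicit leading
  // digit), so the exponent field is 16 - (8 - 1) - 1 = 8 bits.
  assert(16 - (8 - 1) - 1 == 8);
  // F16: 16 bits total, 11 significand digits -> 16 - 10 - 1 = 5 exponent bits.
  assert(16 - (11 - 1) - 1 == 5);
  // F32: 32 bits total, 24 significand digits -> 32 - 23 - 1 = 8 exponent bits.
  assert(32 - (24 - 1) - 1 == 8);
  // OverflowExponent(F32) corresponds to std::numeric_limits<float>::max_exponent:
  // 2^(128 - 1) is a finite float, 2^128 is not.
  static_assert(std::numeric_limits<float>::max_exponent == 128,
                "unexpected float exponent range");
  return 0;
}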
diff --git a/tensorflow/compiler/xla/primitive_util_test.cc b/tensorflow/compiler/xla/primitive_util_test.cc
index 1f765d6da9e..0186b9dc4c8 100644
--- a/tensorflow/compiler/xla/primitive_util_test.cc
+++ b/tensorflow/compiler/xla/primitive_util_test.cc
@@ -42,5 +42,12 @@ TEST(PrimitiveUtilTest, StringToPrimitiveType) {
   EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("preD").status());
 }
 
+TEST(PrimitiveUtilTest, FloatTypes) {
+  EXPECT_EQ(primitive_util::SignificandWidth(F32), 24);
+  EXPECT_EQ(primitive_util::SignificandWidth(BF16), 8);
+  EXPECT_EQ(primitive_util::ExponentWidth(F32), 8);
+  EXPECT_EQ(primitive_util::ExponentWidth(BF16), 8);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 51983a578a9..735b0b71818 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -199,8 +199,9 @@ StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
       auto reduced_precision,
       EmitReducePrecisionIR(
           /*src_ty=*/F32, f32_value,
-          /*dest_exponent_bits=*/primitive_util::kBFloat16ExponentBits,
-          /*dest_mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b));
+          /*dest_exponent_bits=*/primitive_util::ExponentWidth(BF16),
+          /*dest_mantissa_bits=*/primitive_util::SignificandWidth(BF16) - 1,
+          b));
   auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
   auto shifted = b->CreateLShr(as_int32, 16);
   auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index badfe81625e..669388e62cd 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -233,11 +233,9 @@ StatusOr<PrimitiveType> MaybeUpcast(
         "`preferred_element_type` must have the same signedness as the "
         "original type.");
   }
-  if (primitive_util::BitWidth(*preferred_element_type) <
-      primitive_util::BitWidth(from_type)) {
-    if (primitive_util::IsFloatingPointType(from_type)) {
-      return from_type;
-    }
+  if (!primitive_util::IsFloatingPointType(from_type) &&
+      primitive_util::BitWidth(*preferred_element_type) <
+          primitive_util::BitWidth(from_type)) {
     return InvalidArgument(
         "`preferred_element_type` must not be narrower than the original "
         "type.");
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 77f84a69205..2f9029e80aa 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -657,6 +657,54 @@ ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
   return args;
 }
 
+TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
+  ConvolveArgs args = MakeConvolveArgs(BF16, F16);
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape inferred_shape,
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/absl::nullopt))
+  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
+                               inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
+  ConvolveArgs args = MakeConvolveArgs(F16, BF16);
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape inferred_shape,
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/absl::nullopt))
+  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
+                               inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
+  ConvolveArgs args = MakeConvolveArgs(S32, U32);
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape inferred_shape,
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/absl::nullopt))
+  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
+                               inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
+  ConvolveArgs args = MakeConvolveArgs(U32, S32);
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape inferred_shape,
+      ShapeInference::InferConvolveShape(
+          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
+          /*batch_group_count=*/1, args.window, args.dnums,
+          /*preferred_element_type=*/absl::nullopt))
+  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
+                               inferred_shape));
+}
+
 TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
   ConvolveArgs args = MakeConvolveArgs(S8, S16);
   TF_ASSERT_OK_AND_ASSIGN(
@@ -690,7 +738,7 @@ TEST_F(ShapeInferenceTest,
           args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
           /*batch_group_count=*/1, args.window, args.dnums,
           /*preferred_element_type=*/BF16))
-  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
+  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                                inferred_shape));
 }
 
@@ -1714,7 +1762,7 @@ TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
                               ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
                               /*preferred_element_type=*/BF16));
   EXPECT_TRUE(
-      ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
+      ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(BF16, {32, 32})));
 }
 
 TEST_F(ShapeInferenceTest, FloatingPointDotWithInvalidPreferredElementType) {
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 584d948e92c..ec9bae16bf0 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include <functional>
 #include <initializer_list>
+#include <tuple>
 
 #include "absl/base/macros.h"
 #include "absl/container/inlined_vector.h"
@@ -266,21 +267,35 @@ class ShapeUtil {
   // and returns it.
   static PrimitiveType HigherPrecisionElementType(const Shape& a,
                                                   const Shape& b) {
-    if (SameElementType(a, b)) {
+    // Returns a tuple where the elements are lexicographically ordered in
+    // terms of importance.
+    auto type_properties = [](const Shape& shape) {
+      return std::make_tuple(
+          // Prefer floating point types with more range over other
+          // floating-point types or non-floating point types.
+          ElementIsFloating(shape)
+              ? primitive_util::OverflowExponent(shape.element_type())
+              : -1,
+          // Prefer floating point types with more precision over less precise
+          // types.
+          ElementIsFloating(shape)
+              ? primitive_util::SignificandWidth(shape.element_type())
+              : -1,
+          // Prefer wider types over narrower types.
+          primitive_util::BitWidth(shape.element_type()),
+          // Prefer signed integer types over unsigned integer types.
+          primitive_util::IsSignedIntegralType(shape.element_type()));
+    };
+    auto a_properties = type_properties(a);
+    auto b_properties = type_properties(b);
+    if (a_properties > b_properties) {
       return a.element_type();
     }
-    // If only one of A and B are floating use the floating point type.
-    if (ElementIsFloating(a) && !ElementIsFloating(b)) {
-      return a.element_type();
-    }
-    if (ElementIsFloating(b) && !ElementIsFloating(a)) {
+    if (b_properties > a_properties) {
       return b.element_type();
     }
-    // Use the higher precision type.
-    return primitive_util::BitWidth(a.element_type()) <
-                   primitive_util::BitWidth(b.element_type())
-               ? b.element_type()
-               : a.element_type();
+    CHECK(SameElementType(a, b));
+    return a.element_type();
   }
 
   // Returns true if the rank, dimension sizes, and element type are
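To make the new lexicographic preference order in HigherPrecisionElementType easier to follow, here is a small standalone sketch (plain C++, no XLA dependencies; the hard-coded tuples are assumed stand-ins for what OverflowExponent, SignificandWidth, BitWidth and IsSignedIntegralType would report):

// Sketch of the (range, precision, width, signedness) ordering: std::tuple
// comparison is lexicographic, so the first differing field decides.
#include <cassert>
#include <tuple>

// (overflow exponent or -1, significand digits or -1, bit width, signed int?)
using TypeProperties = std::tuple<int, int, int, bool>;

int main() {
  const TypeProperties f32 = std::make_tuple(128, 24, 32, false);
  const TypeProperties bf16 = std::make_tuple(128, 8, 16, false);
  const TypeProperties f16 = std::make_tuple(16, 11, 16, false);
  const TypeProperties s32 = std::make_tuple(-1, -1, 32, true);
  const TypeProperties u32 = std::make_tuple(-1, -1, 32, false);

  // BF16 wins over F16: equal bit width, but more exponent range (128 > 16),
  // which is why ConvolveWithBF16_F16 above expects a BF16 result.
  assert(bf16 > f16);
  // F32 wins over BF16 on the second field (24 vs 8 significand digits).
  assert(f32 > bf16);
  // Any float wins over any integer, since integers get -1 in the float slots.
  assert(f16 > s32);
  // At equal width, signed integers are preferred over unsigned ones, matching
  // the ConvolveWithS32_U32 / ConvolveWithU32_S32 tests.
  assert(s32 > u32);
  return 0;
}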