[XLA] Be more consistent about inferring mixed operand shapes
conv(a, b) should infer the same element type as conv(b, a). We did not have enough tie breaking logic to ensure this would happen. PiperOrigin-RevId: 351426686 Change-Id: Ibd7c0e9c17101c2b95a329c5b66d3b4e77aaae95
This commit is contained in:
parent
5fff2a4bca
commit
6dab73bceb
@ -15,8 +15,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -31,9 +34,42 @@ int SignificandWidth(PrimitiveType type) {
|
||||
case F64:
|
||||
return std::numeric_limits<double>::digits;
|
||||
case BF16:
|
||||
return kBFloat16MantissaBits + 1;
|
||||
return std::numeric_limits<bfloat16>::digits;
|
||||
case F16:
|
||||
return 11;
|
||||
return std::numeric_limits<half>::digits;
|
||||
default:
|
||||
LOG(FATAL) << "Not a floating data type " << type;
|
||||
}
|
||||
}
|
||||
|
||||
int ExponentWidth(PrimitiveType type) {
|
||||
// Per the IEEE-754 standard: a floating point type is stored as a sign bit, a
|
||||
// biased exponent and a trailing significand field.
|
||||
int total_bit_width = BitWidth(type);
|
||||
// This field contains all bits in the significand other than the leading
|
||||
// digit which is implied by the exponent.
|
||||
int trailing_significand_field_width = SignificandWidth(type) - 1;
|
||||
// The sign is encoded with a single bit.
|
||||
int kSignBitWidth = 1;
|
||||
// The remaining bits are used for encoding the biased exponent.
|
||||
return total_bit_width - (trailing_significand_field_width + kSignBitWidth);
|
||||
}
|
||||
|
||||
int OverflowExponent(PrimitiveType type) {
|
||||
// |std::numeric_limits<float>::max_exponent| is defined as: "Maximum positive
|
||||
// integer such that radix raised to the power one less than that integer is a
|
||||
// representable finite floating-point number." as such it does not actually
|
||||
// yield the maximum exponent but the exponent of the first integer which
|
||||
// overflows.
|
||||
switch (type) {
|
||||
case F32:
|
||||
return std::numeric_limits<float>::max_exponent;
|
||||
case F64:
|
||||
return std::numeric_limits<double>::max_exponent;
|
||||
case BF16:
|
||||
return std::numeric_limits<bfloat16>::max_exponent;
|
||||
case F16:
|
||||
return std::numeric_limits<half>::max_exponent;
|
||||
default:
|
||||
LOG(FATAL) << "Not a floating data type " << type;
|
||||
}
|
||||
|
@ -33,12 +33,13 @@ namespace primitive_util {
|
||||
// For non-float datatypes, results in a LOG(FATAL).
|
||||
int SignificandWidth(PrimitiveType type);
|
||||
|
||||
// The number of exponent bits in a BF16 value.
|
||||
const int kBFloat16ExponentBits = 8;
|
||||
// Returns the count of exponent bits for float datatypes.
|
||||
// For non-float datatypes, results in a LOG(FATAL).
|
||||
int ExponentWidth(PrimitiveType type);
|
||||
|
||||
// The number of mantissa bits in a BF16 value. There is an implicit leading
|
||||
// 1, so there is an implicit additional bit of precision.
|
||||
const int kBFloat16MantissaBits = 7;
|
||||
// Returns the exponent of the smallest number which cannot be represented.
|
||||
// For non-float datatypes, results in a LOG(FATAL).
|
||||
int OverflowExponent(PrimitiveType type);
|
||||
|
||||
// Returns the XLA primitive type (eg, F32) corresponding to the given
|
||||
// template parameter native type (eg, float).
|
||||
|
@ -42,5 +42,12 @@ TEST(PrimitiveUtilTest, StringToPrimitiveType) {
|
||||
EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("preD").status());
|
||||
}
|
||||
|
||||
TEST(PrimitiveUtilTest, FloatTypes) {
|
||||
EXPECT_EQ(primitive_util::SignificandWidth(F32), 24);
|
||||
EXPECT_EQ(primitive_util::SignificandWidth(BF16), 8);
|
||||
EXPECT_EQ(primitive_util::ExponentWidth(F32), 8);
|
||||
EXPECT_EQ(primitive_util::ExponentWidth(BF16), 8);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -199,8 +199,9 @@ StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
|
||||
auto reduced_precision,
|
||||
EmitReducePrecisionIR(
|
||||
/*src_ty=*/F32, f32_value,
|
||||
/*dest_exponent_bits=*/primitive_util::kBFloat16ExponentBits,
|
||||
/*dest_mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b));
|
||||
/*dest_exponent_bits=*/primitive_util::ExponentWidth(BF16),
|
||||
/*dest_mantissa_bits=*/primitive_util::SignificandWidth(BF16) - 1,
|
||||
b));
|
||||
auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
|
||||
auto shifted = b->CreateLShr(as_int32, 16);
|
||||
auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
|
||||
|
@ -233,11 +233,9 @@ StatusOr<PrimitiveType> MaybeUpcast(
|
||||
"`preferred_element_type` must have the same signedness as the "
|
||||
"original type.");
|
||||
}
|
||||
if (primitive_util::BitWidth(*preferred_element_type) <
|
||||
primitive_util::BitWidth(from_type)) {
|
||||
if (primitive_util::IsFloatingPointType(from_type)) {
|
||||
return from_type;
|
||||
}
|
||||
if (!primitive_util::IsFloatingPointType(from_type) &&
|
||||
primitive_util::BitWidth(*preferred_element_type) <
|
||||
primitive_util::BitWidth(from_type)) {
|
||||
return InvalidArgument(
|
||||
"`preferred_element_type` must not be narrower than the original "
|
||||
"type.");
|
||||
|
@ -657,6 +657,54 @@ ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
|
||||
return args;
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
|
||||
ConvolveArgs args = MakeConvolveArgs(BF16, F16);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Shape inferred_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, args.window, args.dnums,
|
||||
/*preferred_element_type=*/absl::nullopt))
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
|
||||
inferred_shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
|
||||
ConvolveArgs args = MakeConvolveArgs(F16, BF16);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Shape inferred_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, args.window, args.dnums,
|
||||
/*preferred_element_type=*/absl::nullopt))
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
|
||||
inferred_shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
|
||||
ConvolveArgs args = MakeConvolveArgs(S32, U32);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Shape inferred_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, args.window, args.dnums,
|
||||
/*preferred_element_type=*/absl::nullopt))
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
|
||||
inferred_shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
|
||||
ConvolveArgs args = MakeConvolveArgs(U32, S32);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Shape inferred_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, args.window, args.dnums,
|
||||
/*preferred_element_type=*/absl::nullopt))
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
|
||||
inferred_shape));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
|
||||
ConvolveArgs args = MakeConvolveArgs(S8, S16);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
@ -690,7 +738,7 @@ TEST_F(ShapeInferenceTest,
|
||||
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, args.window, args.dnums,
|
||||
/*preferred_element_type=*/BF16))
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
|
||||
inferred_shape));
|
||||
}
|
||||
|
||||
@ -1714,7 +1762,7 @@ TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
|
||||
ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
|
||||
/*preferred_element_type=*/BF16));
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
|
||||
ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(BF16, {32, 32})));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, FloatingPointDotWithInvalidPreferredElementType) {
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include <initializer_list>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "absl/base/macros.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
@ -266,21 +267,35 @@ class ShapeUtil {
|
||||
// and returns it.
|
||||
static PrimitiveType HigherPrecisionElementType(const Shape& a,
|
||||
const Shape& b) {
|
||||
if (SameElementType(a, b)) {
|
||||
// Returns a tuple where the elements are lexicographically ordered in terms
|
||||
// of importance.
|
||||
auto type_properties = [](const Shape& shape) {
|
||||
return std::make_tuple(
|
||||
// Prefer floating point types with more range over other
|
||||
// floating-point types or non-floating point types.
|
||||
ElementIsFloating(shape)
|
||||
? primitive_util::OverflowExponent(shape.element_type())
|
||||
: -1,
|
||||
// Prefer floating point types with more precision over less precise
|
||||
// types.
|
||||
ElementIsFloating(shape)
|
||||
? primitive_util::SignificandWidth(shape.element_type())
|
||||
: -1,
|
||||
// Prefer wider types over narrower types.
|
||||
primitive_util::BitWidth(shape.element_type()),
|
||||
// Prefer signed integer types over unsigned integer types.
|
||||
primitive_util::IsSignedIntegralType(shape.element_type()));
|
||||
};
|
||||
auto a_properties = type_properties(a);
|
||||
auto b_properties = type_properties(b);
|
||||
if (a_properties > b_properties) {
|
||||
return a.element_type();
|
||||
}
|
||||
// If only one of A and B are floating use the floating point type.
|
||||
if (ElementIsFloating(a) && !ElementIsFloating(b)) {
|
||||
return a.element_type();
|
||||
}
|
||||
if (ElementIsFloating(b) && !ElementIsFloating(a)) {
|
||||
if (b_properties > a_properties) {
|
||||
return b.element_type();
|
||||
}
|
||||
// Use the higher precision type.
|
||||
return primitive_util::BitWidth(a.element_type()) <
|
||||
primitive_util::BitWidth(b.element_type())
|
||||
? b.element_type()
|
||||
: a.element_type();
|
||||
CHECK(SameElementType(a, b));
|
||||
return a.element_type();
|
||||
}
|
||||
|
||||
// Returns true if the rank, dimension sizes, and element type are
|
||||
|
Loading…
Reference in New Issue
Block a user