[XLA] A variant of the Equal method ignoring fp precision.
PiperOrigin-RevId: 202113177
This commit is contained in:
parent
69f147bd99
commit
ed37c8a8a0
@ -94,8 +94,11 @@ bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
|
||||
// Recursive helper for comparing the equality of two shapes. Returns true if
|
||||
// the shapes are the same. If compare_layouts is true, then layouts must also
|
||||
// match.
|
||||
bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
|
||||
if (!ShapeUtil::SameElementType(lhs, rhs)) {
|
||||
bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
|
||||
bool ignore_fp_precision) {
|
||||
if ((ignore_fp_precision &&
|
||||
!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) ||
|
||||
(!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) {
|
||||
VLOG(3) << "CompareShapes: lhs element type != rhs element type";
|
||||
return false;
|
||||
}
|
||||
@ -103,7 +106,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
|
||||
if (ShapeUtil::IsTuple(lhs)) {
|
||||
return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
|
||||
[=](const Shape& l, const Shape& r) {
|
||||
return CompareShapes(l, r, compare_layouts);
|
||||
return CompareShapes(l, r, compare_layouts,
|
||||
ignore_fp_precision);
|
||||
});
|
||||
} else if (!ShapeUtil::IsArray(lhs)) {
|
||||
// Non-tuple, non-array tupes such as opaque and token types are trivially
|
||||
@ -170,7 +174,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
} // namespace
|
||||
|
||||
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
|
||||
bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true);
|
||||
bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true,
|
||||
/*ignore_fp_precision=*/false);
|
||||
if (!equal && VLOG_IS_ON(3)) {
|
||||
VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
|
||||
<< ", rhs = " << rhs.ShortDebugString();
|
||||
@ -179,6 +184,18 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
return equal;
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
|
||||
const Shape& rhs) {
|
||||
bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true,
|
||||
/*ignore_fp_precision=*/true);
|
||||
if (!equal && VLOG_IS_ON(3)) {
|
||||
VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
|
||||
<< lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
|
||||
}
|
||||
|
||||
return equal;
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::Rank(const Shape& shape) {
|
||||
CHECK(ShapeUtil::IsArray(shape))
|
||||
<< "Non-arrays do not have a rank, shape: " << shape;
|
||||
@ -665,7 +682,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
|
||||
return CompareShapes(lhs, rhs, /*compare_layouts=*/false);
|
||||
return CompareShapes(lhs, rhs, /*compare_layouts=*/false,
|
||||
/*ignore_fp_precision=*/false);
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
|
||||
|
@ -280,6 +280,9 @@ class ShapeUtil {
|
||||
// Returns whether the lhs and rhs shapes are identical protobufs.
|
||||
static bool Equal(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
// As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
|
||||
static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
// Returns the rank (number of dimensions) of the given shape.
|
||||
// Precondition: !IsTuple(shape)
|
||||
static int64 Rank(const Shape& shape);
|
||||
|
@ -242,6 +242,24 @@ TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
|
||||
EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, EqualIgnoringFpPrecision) {
|
||||
EXPECT_TRUE(ShapeUtil::EqualIgnoringFpPrecision(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) {
|
||||
EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1})));
|
||||
EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0})));
|
||||
EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, CompatibleTuples) {
|
||||
Shape tuple1 = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
|
||||
|
Loading…
Reference in New Issue
Block a user