[XLA] A variant of the Equal method ignoring fp precision.

PiperOrigin-RevId: 202113177
This commit is contained in:
Thomas Joerg 2018-06-26 05:30:55 -07:00 committed by TensorFlower Gardener
parent 69f147bd99
commit ed37c8a8a0
3 changed files with 44 additions and 5 deletions

View File

@ -94,8 +94,11 @@ bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
// Recursive helper for comparing the equality of two shapes. Returns true if
// the shapes are the same. If compare_layouts is true, then layouts must also
// match.
bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
if (!ShapeUtil::SameElementType(lhs, rhs)) {
bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
bool ignore_fp_precision) {
if ((ignore_fp_precision &&
!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) ||
(!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) {
VLOG(3) << "CompareShapes: lhs element type != rhs element type";
return false;
}
@ -103,7 +106,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
if (ShapeUtil::IsTuple(lhs)) {
return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
[=](const Shape& l, const Shape& r) {
return CompareShapes(l, r, compare_layouts);
return CompareShapes(l, r, compare_layouts,
ignore_fp_precision);
});
} else if (!ShapeUtil::IsArray(lhs)) {
// Non-tuple, non-array tupes such as opaque and token types are trivially
@ -170,7 +174,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
} // namespace
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true);
bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true,
/*ignore_fp_precision=*/false);
if (!equal && VLOG_IS_ON(3)) {
VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
<< ", rhs = " << rhs.ShortDebugString();
@ -179,6 +184,18 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return equal;
}
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) {
bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true,
/*ignore_fp_precision=*/true);
if (!equal && VLOG_IS_ON(3)) {
VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
<< lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
}
return equal;
}
/* static */ int64 ShapeUtil::Rank(const Shape& shape) {
CHECK(ShapeUtil::IsArray(shape))
<< "Non-arrays do not have a rank, shape: " << shape;
@ -665,7 +682,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
return CompareShapes(lhs, rhs, /*compare_layouts=*/false);
return CompareShapes(lhs, rhs, /*compare_layouts=*/false,
/*ignore_fp_precision=*/false);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,

View File

@ -280,6 +280,9 @@ class ShapeUtil {
// Returns whether the lhs and rhs shapes are identical protobufs.
static bool Equal(const Shape& lhs, const Shape& rhs);
// As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
// Returns the rank (number of dimensions) of the given shape.
// Precondition: !IsTuple(shape)
static int64 Rank(const Shape& shape);

View File

@ -242,6 +242,24 @@ TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2));
}
TEST(ShapeUtilTest, EqualIgnoringFpPrecision) {
EXPECT_TRUE(ShapeUtil::EqualIgnoringFpPrecision(
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
}
TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) {
EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1})));
EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}),
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0})));
EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
}
TEST(ShapeUtilTest, CompatibleTuples) {
Shape tuple1 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});