From ed37c8a8a07734a4eb13e14d7d7b67c81a2968b7 Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Tue, 26 Jun 2018 05:30:55 -0700 Subject: [PATCH] [XLA] A variant of the Equal method ignoring fp precision. PiperOrigin-RevId: 202113177 --- tensorflow/compiler/xla/shape_util.cc | 28 ++++++++++++++++++---- tensorflow/compiler/xla/shape_util.h | 3 +++ tensorflow/compiler/xla/shape_util_test.cc | 18 ++++++++++++++ 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 98c3095499f..dda72b5e75e 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -94,8 +94,11 @@ bool IsArrayPrimitiveType(PrimitiveType primitive_type) { // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. -bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (!ShapeUtil::SameElementType(lhs, rhs)) { +bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, + bool ignore_fp_precision) { + if ((ignore_fp_precision && + !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || + (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) { VLOG(3) << "CompareShapes: lhs element type != rhs element type"; return false; } @@ -103,7 +106,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (ShapeUtil::IsTuple(lhs)) { return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts); + return CompareShapes(l, r, compare_layouts, + ignore_fp_precision); }); } else if (!ShapeUtil::IsArray(lhs)) { // Non-tuple, non-array tupes such as opaque and token types are trivially @@ -170,7 +174,8 @@ StatusOr MakeShapeWithLayoutInternal( } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true); + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/false); if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -179,6 +184,18 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, + const Shape& rhs) { + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/true); + if (!equal && VLOG_IS_ON(3)) { + VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = " + << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); + } + + return equal; +} + /* static */ int64 ShapeUtil::Rank(const Shape& shape) { CHECK(ShapeUtil::IsArray(shape)) << "Non-arrays do not have a rank, shape: " << shape; @@ -665,7 +682,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return CompareShapes(lhs, rhs, /*compare_layouts=*/false); + return CompareShapes(lhs, rhs, /*compare_layouts=*/false, + /*ignore_fp_precision=*/false); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 02e4f41505f..2840d003333 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -280,6 +280,9 @@ class ShapeUtil { // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); + // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. + static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Returns the rank (number of dimensions) of the given shape. // Precondition: !IsTuple(shape) static int64 Rank(const Shape& shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 606f7492cea..b6f30af381d 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -242,6 +242,24 @@ TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2)); } +TEST(ShapeUtilTest, EqualIgnoringFpPrecision) { + EXPECT_TRUE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1}))); +} + +TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) { + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); +} + TEST(ShapeUtilTest, CompatibleTuples) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});