From c0a4c7ffc2ebb06c544a0764a5812fdc2ddba336 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Mon, 11 Sep 2017 20:10:28 -0700 Subject: [PATCH] [XLA] Fix bug in ShapeUtil::ShapeIs that would lead to type inference errors. PiperOrigin-RevId: 168323589 --- tensorflow/compiler/xla/shape_util.cc | 16 +--------------- tensorflow/compiler/xla/shape_util.h | 2 ++ tensorflow/compiler/xla/shape_util_test.cc | 4 ++++ 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index b71b3a9e131..dc46e2bbe9c 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -299,21 +299,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ bool ShapeUtil::ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions) { - TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); - if (shape.element_type() != element_type) { - return false; - } - if (shape.dimensions_size() != Rank(shape)) { - return false; - } - int64 i = 0; - for (int64 dimension : dimensions) { - if (shape.dimensions(i) != dimension) { - return false; - } - i += 1; - } - return true; + return Equal(shape, MakeShape(element_type, dimensions)); } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index e3473138376..6de61e5e045 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -296,6 +296,8 @@ class ShapeUtil { // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. + // + // DEPRECATED: Use Equal() instead. static bool ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 9635e5ad2eb..79945b9c772 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -491,6 +491,10 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); } +TEST(ShapeUtilTest, ShapeIs) { + EXPECT_FALSE(ShapeUtil::ShapeIs(ShapeUtil::MakeShape(PRED, {2}), PRED, {})); +} + TEST(ShapeUtilTest, ForEachIndex) { struct ShapeDimensionAndNumberInvocations { std::vector dimensions;