diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 0833919b124..c9f37bdc430 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1074,7 +1074,8 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, } } - return std::make_tuple(true, deleted_indices, inserted_indices); + return std::make_tuple(!deleted_indices.empty() || !inserted_indices.empty(), + deleted_indices, inserted_indices); } /* static */ std::vector> diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4e2030667ee..414b53d4f67 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -558,6 +558,8 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1))); EXPECT_FALSE(std::get<0>( ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); + EXPECT_FALSE(std::get<0>( + ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape0))); } TEST(ShapeUtilTest, ForEachIndex) {