diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 97ced5dfdc1..b9407818cd8 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -905,7 +905,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
 
   // If no implicit broadcast is needed for this operand, returns the target
   // index as the source index.
-  if (ShapeUtil::Compatible(operand_shape, hlo.shape())) {
+  if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
     return target_index;
   }
 
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index c0a0e13f073..74fa0b2f2e7 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -553,6 +553,16 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
   return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs);
 }
 
+/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
+                                                           const Shape& rhs) {
+  if (lhs.element_type() == TUPLE) {
+    return rhs.element_type() == TUPLE &&
+           ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
+                           CompatibleIgnoringElementType);
+  }
+  return SameDimensions(lhs, rhs);
+}
+
 /* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
                                            int64 dimension_number) {
   return shape.dimensions(GetDimensionNumber(shape, dimension_number));
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 82a513a65ad..2ea1bd95cb5 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -190,6 +190,11 @@ class ShapeUtil {
   // compatibility.
   static bool Compatible(const Shape& lhs, const Shape& rhs);
 
+  // Returns true if the rank and dimension sizes are identical. Element type
+  // and layout are ignored. Tuple elements are compared recursively for
+  // compatibility.
+  static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
+
   // Returns whether the lhs and rhs shapes are identical protobufs.
   static bool Equal(const Shape& lhs, const Shape& rhs);
 
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 0ba542ad1be..4bce7ca51d0 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -145,6 +145,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
   Shape tuple2 = ShapeUtil::MakeTupleShape(
       {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
   EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
+  EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
 }
 
 TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
@@ -153,6 +154,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
   Shape tuple2 = ShapeUtil::MakeTupleShape(
       {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})});
   EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
+  EXPECT_TRUE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
 }
 
 TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) {