From 9be9961059473c85884b80419a402637a716c4cc Mon Sep 17 00:00:00 2001
From: Alfie Edwards
Date: Mon, 30 Nov 2020 14:22:04 +0000
Subject: [PATCH] Relax requirements for CloneWithNewOperands to preserve
 sharding information

Reviewers: #tensorflow!, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved!

Differential Revision: https://phabricator.sourcevertex.net/D36993
---
 .../compiler/xla/service/hlo_instruction.cc   | 11 +++---
 .../xla/service/hlo_instruction_test.cc       | 39 +++++++++++++++++++
 tensorflow/compiler/xla/shape.cc              |  8 ++--
 tensorflow/compiler/xla/shape.h               |  5 +++
 tensorflow/compiler/xla/shape_util.cc         |  6 +++
 tensorflow/compiler/xla/shape_util.h          |  6 +++
 6 files changed, 67 insertions(+), 8 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8d33664e38e..edbc305d926 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1497,11 +1497,12 @@ void HloInstruction::set_single_sharding(const HloSharding& sharding) {
 
 void HloInstruction::SetupDerivedInstruction(
     HloInstruction* derived_instruction) const {
-  if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType(
-                                  shape_, derived_instruction->shape())) {
-    // Only copy sharding if the shape of the two instruction is compatible
-    // because copying it between differently shaped instructions can produce
-    // invalid shardings.
+  if (sharding_ != nullptr &&
+      ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(
+          shape_, derived_instruction->shape())) {
+    // Only copy sharding if the tuple tree shapes of the two instructions
+    // are compatible, because copying it between differently shaped
+    // instructions can produce invalid shardings.
     derived_instruction->set_sharding(*sharding_);
   } else {
     derived_instruction->clear_sharding();
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index e5735bea843..37c1282a342 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -749,6 +749,45 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
   EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape()));
 }
 
+TEST_F(HloInstructionTest, PreserveShardingThroughCompatibleClone) {
+
+  HloSharding sharding = HloSharding::AssignDevice(5);
+  HloComputation::Builder builder(TestName());
+  auto* constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
+          {1, 2},
+          {3, 4},
+      })));
+  auto* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
+  tuple->set_sharding(sharding);
+  // Compatible with the original shape, as the tuple tree structure is identical.
+  auto clone_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
+  clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape});
+  auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
+  EXPECT_EQ(tuple_clone->sharding(), sharding);
+}
+
+TEST_F(HloInstructionTest, DoNotPreserveShardingThroughIncompatibleClone) {
+
+  HloSharding sharding = HloSharding::AssignDevice(5);
+  HloComputation::Builder builder(TestName());
+  auto* constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
+          {1, 2},
+          {3, 4},
+      })));
+  auto* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
+  tuple->set_sharding(sharding);
+  // Incompatible with the original shape, as the tuple tree structure differs.
+  auto clone_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
+  clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape,
+                                           clone_shape});
+  auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
+  EXPECT_FALSE(tuple_clone->has_sharding());
+}
+
 TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
   // Create a fusion instruction containing a single unary operation.
   const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc
index d1d5dc17083..cd9a5d75885 100644
--- a/tensorflow/compiler/xla/shape.cc
+++ b/tensorflow/compiler/xla/shape.cc
@@ -141,9 +141,11 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
     }
   }
 
-  if (!ShapeUtil::SameDimensions(lhs, rhs)) {
-    VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
-    return false;
+  if (!ignore_dimensions_) {
+    if (!ShapeUtil::SameDimensions(lhs, rhs)) {
+      VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
+      return false;
+    }
   }
 
   if (!ignore_layout_) {
diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h
index 0c9a2f3ab54..7eab897de9d 100644
--- a/tensorflow/compiler/xla/shape.h
+++ b/tensorflow/compiler/xla/shape.h
@@ -220,6 +220,10 @@ class Shape {
       ignore_dynamic_dimension_ = true;
       return *this;
     }
+    Equal& IgnoreDimensions() {
+      ignore_dimensions_ = true;
+      return *this;
+    }
 
    private:
     bool ignore_layout_ = false;
@@ -229,6 +233,7 @@ class Shape {
     bool ignore_element_type_ = false;
     bool ignore_fp_precision_ = false;
     bool ignore_dynamic_dimension_ = false;
+    bool ignore_dimensions_ = false;
   };
 
   // Test that all fields of the shape are the same, equivalent to Equal().
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index e84a2591707..00e9eeac5ba 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -654,6 +654,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
       .IgnoreLayout()(lhs, rhs);
 }
 
+/* static */ bool ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(
+    const Shape& lhs, const Shape& rhs) {
+  return Shape::Equal().IgnoreElementType().IgnoreLayout().IgnoreDimensions()
+      .IgnoreDynamicDimension()(lhs, rhs);
+}
+
 /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
                                                            const Shape& rhs) {
   return Shape::Equal()
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index ff47ab6ea80..ec9fbf50a90 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -293,6 +293,12 @@ class ShapeUtil {
   // compatibility.
   static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
 
+  // Returns true if the tuple tree shapes are identical. Leaf dimensions,
+  // element type, and layout are ignored. Tuple elements are compared
+  // recursively for compatibility.
+  static bool CompatibleIgnoringElementTypeAndDimensions(const Shape& lhs,
+                                                         const Shape& rhs);
+
  // As Compatible, but allow one of lhs and rhs to be BF16 while the other
  // being F32. Tuple elements are compared recursively for compatibility.
  static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
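
Illustration (not part of the patch): a minimal sketch of what the new predicate accepts and rejects, mirroring the two tests above. It assumes only the ShapeUtil API declared in this patch; the CompatibilityExample function and the concrete shapes are hypothetical.

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void CompatibilityExample() {
  // Same tuple tree structure; leaf ranks, dimensions, and element types all
  // differ. The new predicate accepts this pair, so SetupDerivedInstruction
  // would copy the sharding across.
  Shape a = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 2}), ShapeUtil::MakeShape(F32, {2, 2})});
  Shape b = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {1, 2, 3}), ShapeUtil::MakeShape(S32, {4})});
  CHECK(ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(a, b));
  // The stricter predicate used previously rejects the same pair because the
  // leaf dimensions differ.
  CHECK(!ShapeUtil::CompatibleIgnoringElementType(a, b));

  // Different tuple arity: the tuple tree structures differ, so the new
  // predicate rejects the pair and the derived instruction's sharding is
  // cleared instead.
  Shape c = ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2}),
                                       ShapeUtil::MakeShape(F32, {2, 2}),
                                       ShapeUtil::MakeShape(F32, {2, 2})});
  CHECK(!ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(a, c));
}

}  // namespace xla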