diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 082bf8bffed..25d53275611 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints(
         TF_RETURN_IF_ERROR(
             constraints->SetBufferLayout(new_shape.layout(), *buffer));
       }
+    } else if (instruction->IsCrossModuleAllReduce()) {
+      CHECK(get_channel_constraints(instruction))
+          << "Multi-module layout assignment requires ChannelLayoutConstraints";
+      int64 all_reduce_id = instruction->all_reduce_id().value();
+      if (!get_channel_constraints(instruction)
+               ->IsChannelConstrained(all_reduce_id)) {
+        continue;
+      }
+      // TODO(b/68493863): Change to use SetOperandLayout().
+      const Shape& buffer_shape = instruction->operand(0)->shape();
+      TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape));
+      Shape new_buffer_shape =
+          get_channel_constraints(instruction)
+              ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
+      TF_RETURN_IF_ERROR(
+          constraints->SetInstructionLayout(new_buffer_shape, instruction));
     }
   }
 
@@ -1512,19 +1528,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
     // Verify all layouts in the shape have been set.
     TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
   }
-
-  // Copy the root instruction's result if its layout does not match the result
-  // layout constraint.
-  if (constraints.ResultLayout() != nullptr &&
-      !constraints.ResultLayout()->MatchesLayoutInShape(
-          computation->root_instruction()->shape())) {
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * new_root,
-        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
-                                computation->root_instruction()));
-    computation->set_root_instruction(new_root);
-  }
-
   return Status::OK();
 }
 
@@ -1654,6 +1657,18 @@ Status LayoutAssignment::RunOnComputation(
     TF_RETURN_IF_ERROR(
         ConstrainChannelLayouts(computation, channel_constraints));
   }
+
+  // Copy the root instruction's result if its layout does not match the result
+  // layout constraint.
+  if (constraints.ResultLayout() != nullptr &&
+      !constraints.ResultLayout()->MatchesLayoutInShape(
+          computation->root_instruction()->shape())) {
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * new_root,
+        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+                                computation->root_instruction()));
+    computation->set_root_instruction(new_root);
+  }
 
   return Status::OK();
 }
@@ -1709,6 +1724,30 @@ Status LayoutAssignment::ConstrainChannelLayouts(
             ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
         *send_shape = shape;
       }
+    } else if (instruction->IsCrossModuleAllReduce()) {
+      const Layout* layout =
+          get_channel_constraints(instruction)
+              ->ConstrainChannel(instruction->all_reduce_id().value(),
+                                 instruction->shape().layout());
+      if (layout != nullptr) {
+        // We found an already constrained layout which does not match the one
+        // the channel wants to impose. Either add a new kCopy, or use the
+        // existing one to marshal the correct shape.
+        HloInstruction* operand = instruction->mutable_operand(0);
+        Shape shape = operand->shape();
+        *shape.mutable_layout() = *layout;
+        if (operand->opcode() != HloOpcode::kCopy) {
+          HloInstruction* copy = operand->parent()->AddInstruction(
+              HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
+          RegisterAddedCopy(copy);
+          SetupCopiedInstruction(*operand, copy, {});
+          TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
+          operand = copy;
+        } else {
+          *operand->mutable_shape() = shape;
+        }
+        *instruction->mutable_shape() = shape;
+      }
     }
   }
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 752a61476dd..10f9a951212 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -860,6 +860,50 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
       ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
 }
 
+TEST_F(LayoutAssignmentTest, AllReduceLayoutMismatch) {
+  // Pin non-matching layouts to parameter and root.
+  const char* module_str = R"(
+    HloModule test_module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY entry_computation {
+      param = (f32[2,2]) parameter(0)
+      gte = f32[2,2] get-tuple-element(param), index=0
+      ar.0 = f32[2,2] cross-replica-sum(gte),
+        all_reduce_id=0, replica_groups={{0}}, to_apply=add,
+        sharding={maximal device=0}
+      const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}})
+      ROOT ar.1 = f32[2,2] cross-replica-sum(const),
+        all_reduce_id=0, replica_groups={{0}}, to_apply=add,
+        sharding={maximal device=1}
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape());
+  Shape param_shape = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
+  TF_ASSERT_OK(
+      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
+          param_shape));
+  computation_layout.mutable_result_layout()->ResetLayout(
+      LayoutUtil::MakeLayout({1, 0}));
+
+  ChannelLayoutConstraints channel_constraints;
+  AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+
+  EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
+  EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1));
+  EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1));
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
+}
+
 TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
   const char* module_str = R"(
 HloModule CopySliceOperandToAvoidImplicitLayoutChange