Fix layout assignment for cross module all reduce
Previously we could have ended up with different HLOs being assigned different layouts, which made lowering impossible. This change enforces a consistent layout between the communicating nodes, the same way it is done for send & recv pairs.

PiperOrigin-RevId: 215359420
This commit is contained in:
parent edea1be5dd
commit 44da41e490
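The core idea is channel-style layout pinning: all cross-module all-reduce instructions that share an all_reduce_id must agree on one layout. Below is a minimal, self-contained C++ sketch of that idea (illustrative names only, not the real ChannelLayoutConstraints API): the first participant of a channel pins a layout, and any later participant that asks for a different layout is detected so the caller can reconcile it, which the pass below does by inserting a kCopy to the pinned layout.

    // Minimal sketch of channel layout pinning, keyed by all_reduce_id.
    // Names and types are illustrative; the real XLA class differs.
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <vector>

    using Layout = std::vector<int64_t>;  // stand-in for minor_to_major

    class ChannelConstraintsSketch {
     public:
      // Pins `layout` if the channel is unconstrained and returns nullptr;
      // also returns nullptr if `layout` matches the pinned one. Otherwise
      // returns the previously pinned layout so the caller can reconcile.
      const Layout* ConstrainChannel(int64_t channel_id, const Layout& layout) {
        auto it = constraints_.find(channel_id);
        if (it == constraints_.end()) {
          constraints_.emplace(channel_id, layout);
          return nullptr;
        }
        return it->second == layout ? nullptr : &it->second;
      }

     private:
      std::map<int64_t, Layout> constraints_;
    };

    int main() {
      ChannelConstraintsSketch constraints;
      // ar.0 on the first module pins layout {0,1} for all_reduce_id=0.
      constraints.ConstrainChannel(/*channel_id=*/0, /*layout=*/{0, 1});
      // ar.1 on the second module asks for {1,0}; the mismatch is reported,
      // and a real pass would insert a copy to the pinned layout.
      if (const Layout* pinned = constraints.ConstrainChannel(0, {1, 0})) {
        std::cout << "layout mismatch; convert operand to pinned layout\n";
      }
      return 0;
    }

In the diff below, ConstrainChannel plays this role for cross-module all-reduce in ConstrainChannelLayouts, and AddMandatoryConstraints consumes the pinned layout via LayoutShapeForChannel.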
@@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints(
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(new_shape.layout(), *buffer));
      }
    } else if (instruction->IsCrossModuleAllReduce()) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 all_reduce_id = instruction->all_reduce_id().value();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(all_reduce_id)) {
        continue;
      }
      // TODO(b/68493863): Change to use SetOperandLayout().
      const Shape& buffer_shape = instruction->operand(0)->shape();
      TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape));
      Shape new_buffer_shape =
          get_channel_constraints(instruction)
              ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(new_buffer_shape, instruction));
    }
  }

@@ -1512,19 +1528,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }

  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape())) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }

  return Status::OK();
}

@@ -1654,6 +1657,18 @@ Status LayoutAssignment::RunOnComputation(
    TF_RETURN_IF_ERROR(
        ConstrainChannelLayouts(computation, channel_constraints));
  }

  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape())) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }
  return Status::OK();
}

@@ -1709,6 +1724,30 @@ Status LayoutAssignment::ConstrainChannelLayouts(
            ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
        *send_shape = shape;
      }
    } else if (instruction->IsCrossModuleAllReduce()) {
      const Layout* layout =
          get_channel_constraints(instruction)
              ->ConstrainChannel(instruction->all_reduce_id().value(),
                                 instruction->shape().layout());
      if (layout != nullptr) {
        // We found an already constrained layout which does not match the one
        // the channel wants to impose. Either add a new kCopy, or use the
        // existing one to marshal the correct shape.
        HloInstruction* operand = instruction->mutable_operand(0);
        Shape shape = operand->shape();
        *shape.mutable_layout() = *layout;
        if (operand->opcode() != HloOpcode::kCopy) {
          HloInstruction* copy = operand->parent()->AddInstruction(
              HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
          RegisterAddedCopy(copy);
          SetupCopiedInstruction(*operand, copy, {});
          TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
          operand = copy;
        } else {
          *operand->mutable_shape() = shape;
        }
        *instruction->mutable_shape() = shape;
      }
    }
  }
  return Status::OK();

@@ -860,6 +860,50 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
                  ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}

TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
  // Pin non matching layouts to parameter and root.
  const char* module_str = R"(
    HloModule test_module

    add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT add = f32[] add(lhs, rhs)
    }

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      ar.0 = f32[2,2] cross-replica-sum(gte),
        all_reduce_id=0, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=0}
      const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}})
      ROOT ar.1 = f32[2,2] cross-replica-sum(const),
        all_reduce_id=0, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=1}
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      module->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(module.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
}

TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopySliceOperandToAvoidImplicitLayoutChange