Fix layout assignment for cross-module all-reduce

Previously, the communicating HLOs could end up being assigned different
layouts, which made lowering impossible. This change enforces a consistent
layout between the communicating nodes, the same way it is done for
send&recv pairs.
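
The consistent layout is enforced through the same ChannelLayoutConstraints
table that multi-module layout assignment already uses for send&recv. A
minimal sketch of the contract the change relies on, assuming the return-value
behavior implied by the diff below (ConstrainChannel hands back the previously
registered layout only when it differs from the requested one):

  // Sketch only, not part of the commit.
  ChannelLayoutConstraints channel_constraints;
  const int64 all_reduce_id = 0;  // illustrative channel id

  // The first participant registers its layout; nothing was constrained for
  // this channel yet, so ConstrainChannel returns nullptr.
  const Layout* previous = channel_constraints.ConstrainChannel(
      all_reduce_id, LayoutUtil::MakeLayout({0, 1}));
  CHECK(previous == nullptr);

  // A later participant asking for a different layout gets the registered
  // layout back and has to marshal its operand into it (see the kCopy logic
  // in ConstrainChannelLayouts below).
  previous = channel_constraints.ConstrainChannel(
      all_reduce_id, LayoutUtil::MakeLayout({1, 0}));
  CHECK(previous != nullptr);  // *previous is the {0,1} layout from above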

PiperOrigin-RevId: 215359420
Authored by A. Unique TensorFlower on 2018-10-02 03:01:09 -07:00; committed by TensorFlower Gardener
parent edea1be5dd
commit 44da41e490
2 changed files with 96 additions and 13 deletions

tensorflow/compiler/xla/service/layout_assignment.cc

@@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints(
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(new_shape.layout(), *buffer));
      }
    } else if (instruction->IsCrossModuleAllReduce()) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 all_reduce_id = instruction->all_reduce_id().value();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(all_reduce_id)) {
        continue;
      }
      // TODO(b/68493863): Change to use SetOperandLayout().
      const Shape& buffer_shape = instruction->operand(0)->shape();
      TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape));
      Shape new_buffer_shape =
          get_channel_constraints(instruction)
              ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(new_buffer_shape, instruction));
    }
  }
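
For a module whose channel is already constrained (e.g. because the other
participant was laid out first), the new AddMandatoryConstraints branch above
boils down to the following sketch; channel_constraints, all_reduce_id,
constraints and instruction stand in for the objects the pass holds at that
point:

  // Seed the all-reduce's layout from the already constrained channel.
  const Shape& buffer_shape = instruction->operand(0)->shape();
  if (channel_constraints->IsChannelConstrained(all_reduce_id)) {
    // LayoutShapeForChannel returns buffer_shape with the channel's layout
    // applied; that shape is then imposed on the whole instruction.
    Shape new_buffer_shape =
        channel_constraints->LayoutShapeForChannel(buffer_shape, all_reduce_id);
    TF_RETURN_IF_ERROR(
        constraints->SetInstructionLayout(new_buffer_shape, instruction));
  }
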
@@ -1512,19 +1528,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }
  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape())) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }
  return Status::OK();
}
@@ -1654,6 +1657,18 @@ Status LayoutAssignment::RunOnComputation(
    TF_RETURN_IF_ERROR(
        ConstrainChannelLayouts(computation, channel_constraints));
  }
  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape())) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }
  return Status::OK();
}
@@ -1709,6 +1724,30 @@ Status LayoutAssignment::ConstrainChannelLayouts(
            ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
        *send_shape = shape;
      }
    } else if (instruction->IsCrossModuleAllReduce()) {
      const Layout* layout =
          get_channel_constraints(instruction)
              ->ConstrainChannel(instruction->all_reduce_id().value(),
                                 instruction->shape().layout());
      if (layout != nullptr) {
        // We found an already constrained layout which does not match the one
        // the channel wants to impose. Either add a new kCopy, or use the
        // existing one to marshal the correct shape.
        HloInstruction* operand = instruction->mutable_operand(0);
        Shape shape = operand->shape();
        *shape.mutable_layout() = *layout;
        if (operand->opcode() != HloOpcode::kCopy) {
          HloInstruction* copy = operand->parent()->AddInstruction(
              HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
          RegisterAddedCopy(copy);
          SetupCopiedInstruction(*operand, copy, {});
          TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
          operand = copy;
        } else {
          *operand->mutable_shape() = shape;
        }
        *instruction->mutable_shape() = shape;
      }
    }
  }
  return Status::OK();
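
The branch above handles a participant that disagrees with an already
registered channel layout: its operand is routed through a kCopy carrying the
channel layout (or an existing kCopy operand is simply re-laid-out), and the
all-reduce takes that layout as well. The test added below simulates the two
modules with shardings inside a single module; in a real cross-module compile
the point is that one ChannelLayoutConstraints object is shared across the
per-module layout-assignment runs. A hypothetical sketch (the module and
layout variable names are made up; AssignLayouts is the test helper used
below):

  ChannelLayoutConstraints channel_constraints;
  // The first run registers the channel layout for the shared all_reduce_id;
  // the second run sees that layout and inserts copies where its own layout
  // choices disagree.
  AssignLayouts(module_for_device_0.get(), &layout_for_device_0,
                &channel_constraints);
  AssignLayouts(module_for_device_1.get(), &layout_for_device_1,
                &channel_constraints);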

tensorflow/compiler/xla/service/layout_assignment_test.cc

@@ -860,6 +860,50 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
                  ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}

TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
  // Pin non matching layouts to parameter and root.
  const char* module_str = R"(
    HloModule test_module

    add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT add = f32[] add(lhs, rhs)
    }

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      ar.0 = f32[2,2] cross-replica-sum(gte),
        all_reduce_id=0, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=0}
      const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}})
      ROOT ar.1 = f32[2,2] cross-replica-sum(const),
        all_reduce_id=0, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=1}
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      module->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));
  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(module.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
}

TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopySliceOperandToAvoidImplicitLayoutChange