Add optional layout constraint to AllToAll

PiperOrigin-RevId: 308740413
Change-Id: I48773db125a4c542d1c8c46c71100bda4a2b1108
HyoukJoong Lee 2020-04-27 18:34:39 -07:00 committed by TensorFlower Gardener
parent c1d09d25d2
commit e2ccb9a5a5
13 changed files with 154 additions and 70 deletions

View File

@ -2307,7 +2307,8 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups) {
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@ -2342,7 +2343,21 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
if (layout) {
TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape));
for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
if (layout->minor_to_major().size() != shape.tuple_shapes(i).rank()) {
return InvalidArgument(
"Provided layout must be compatible with the operand shape: %s "
"vs %s",
layout->ToString(), operand_shape->ToString());
}
*(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
}
}
*instr.mutable_shape() = shape.ToProto();
for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;
}
@ -3466,9 +3481,10 @@ XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups) {
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout) {
return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
split_count, replica_groups);
split_count, replica_groups, layout);
}
XlaOp CollectivePermute(

View File

@ -557,7 +557,8 @@ class XlaBuilder {
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout = absl::nullopt);
XlaOp CollectivePermute(
XlaOp operand,
@ -993,7 +994,8 @@ class XlaBuilder {
const absl::optional<Shape>& shape_with_layout);
friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout);
friend XlaOp CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs);
@ -1789,9 +1791,13 @@ XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
const absl::optional<Shape>& shape_with_layout = absl::nullopt);
// Enqueues an operation that does an AllToAll of the operand across cores.
// An optional `layout` can be specified to force the layout of the instruction.
// This is used to guarantee the same layout for a group of AllToAll ops
// compiled separately.
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
int64 split_count,
const std::vector<ReplicaGroup>& replica_groups = {});
const std::vector<ReplicaGroup>& replica_groups = {},
const absl::optional<Layout>& layout = absl::nullopt);
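
As a usage sketch (illustrative only, not part of this change: the shape, split parameters, and names are assumptions), a client can now force the result layout when building the op:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Illustrative sketch: force a {0,1} result layout so that separately
// compiled computations using this AllToAll agree on it.
xla::XlaComputation BuildWithForcedLayout() {
  xla::XlaBuilder b("a2a_example");
  xla::XlaOp p = xla::Parameter(
      &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {16, 4}), "p");
  xla::Layout layout = xla::LayoutUtil::MakeLayout({0, 1});
  xla::AllToAll(p, /*split_dimension=*/0, /*concat_dimension=*/0,
                /*split_count=*/2, /*replica_groups=*/{}, layout);
  return b.Build().ValueOrDie();
}
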
// Enqueues a collective operation that sends and receives data across replicas.
//

View File

@ -330,7 +330,8 @@ void BuildOpsSubmodule(py::module* m) {
py::arg("shape_with_layout") = absl::nullopt);
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
py::arg("concat_dimension"), py::arg("split_count"),
py::arg("replica_groups") = py::list());
py::arg("replica_groups") = py::list(),
py::arg("layout") = absl::nullopt);
ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
py::arg("source_target_pairs"));
ops.def("CreateToken", &CreateToken, py::arg("builder"));

View File

@ -289,7 +289,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
replica_groups[0].add_replica_ids(1);
HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), {a, a},
replica_groups, absl::nullopt));
replica_groups, /*constrain_layout=*/false, absl::nullopt));
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(Normalize(module.get()));
@ -318,7 +318,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) {
replica_groups[0].add_replica_ids(1);
HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), {a, a},
replica_groups, absl::nullopt));
replica_groups, /*constrain_layout=*/false, absl::nullopt));
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(Normalize(module.get()));

View File

@ -54,8 +54,9 @@ StatusOr<bool> HloDCE::RunOnComputation(
(remove_cross_partition_collective_ops &&
((instruction->opcode() == HloOpcode::kAllReduce &&
!Cast<HloAllReduceInstruction>(instruction)->constrain_layout()) ||
instruction->opcode() == HloOpcode::kCollectivePermute ||
instruction->opcode() == HloOpcode::kAllToAll)))) {
(instruction->opcode() == HloOpcode::kAllToAll &&
!Cast<HloAllToAllInstruction>(instruction)->constrain_layout()) ||
instruction->opcode() == HloOpcode::kCollectivePermute)))) {
dead_roots.push_back(instruction);
}
}

View File

@ -430,6 +430,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
/*replica_groups=*/
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
proto.replica_groups().end()),
/*constrain_layout=*/proto.constrain_layout(),
/*channel_id=*/channel_id, split_dimension);
break;
}
@ -939,11 +940,12 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id,
const absl::optional<int64>& split_dimension) {
return absl::make_unique<HloAllToAllInstruction>(
shape, operands, replica_groups, channel_id, split_dimension);
shape, operands, replica_groups, constrain_layout, channel_id,
split_dimension);
}
/* static */ std::unique_ptr<HloInstruction>
@ -1375,6 +1377,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
case HloOpcode::kAllReduce:
return channel_id().has_value() ||
Cast<HloAllReduceInstruction>(this)->constrain_layout();
case HloOpcode::kAllToAll:
return Cast<HloAllToAllInstruction>(this)->constrain_layout();
case HloOpcode::kCustomCall:
return Cast<HloCustomCallInstruction>(this)
->custom_call_has_side_effect();

View File

@ -667,7 +667,7 @@ class HloInstruction {
// It is used to implement the higher-level instruction in XlaBuilder.
static std::unique_ptr<HloInstruction> CreateAllToAll(
const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id,
const absl::optional<int64>& split_dimension = absl::nullopt);
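
For reference, a sketch of how the updated factory is called (mirroring the test updates later in this change; the shapes, operands, and replica groups here are placeholders):

// Illustrative: constrain_layout=true marks the layouts already set in the
// result shape as fixed by the client, so they must not be reassigned.
std::unique_ptr<HloInstruction> a2a = HloInstruction::CreateAllToAll(
    ShapeUtil::MakeTupleShape({operand_shape, operand_shape}), {op0, op1},
    replica_groups, /*constrain_layout=*/true,
    /*channel_id=*/absl::nullopt);
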

View File

@ -513,10 +513,11 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
HloCollectiveInstruction::HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id)
: HloChannelInstruction(opcode, shape, channel_id),
replica_groups_(replica_groups) {
replica_groups_(replica_groups),
constrain_layout_(constrain_layout) {
for (auto operand : operands) {
AppendOperand(operand);
}
@ -526,6 +527,7 @@ HloInstructionProto HloCollectiveInstruction::ToProto() const {
HloInstructionProto proto = HloChannelInstruction::ToProto();
*proto.mutable_replica_groups() = {replica_groups_.begin(),
replica_groups_.end()};
proto.set_constrain_layout(constrain_layout_);
return proto;
}
@ -535,6 +537,9 @@ std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
HloChannelInstruction::ExtraAttributesToStringImpl(options);
result.push_back(
StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
if (constrain_layout_) {
result.push_back("constrain_layout=true");
}
return result;
}
@ -557,8 +562,7 @@ HloAllReduceInstruction::HloAllReduceInstruction(
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id, bool use_global_device_ids)
: HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
replica_groups, channel_id),
constrain_layout_(constrain_layout),
replica_groups, constrain_layout, channel_id),
use_global_device_ids_(use_global_device_ids) {
AppendComputation(reduce_computation);
}
@ -574,7 +578,6 @@ bool HloAllReduceInstruction::IsNoop() const {
HloInstructionProto HloAllReduceInstruction::ToProto() const {
HloInstructionProto proto = HloCollectiveInstruction::ToProto();
proto.set_constrain_layout(constrain_layout_);
proto.set_use_global_device_ids(use_global_device_ids_);
return proto;
}
@ -583,9 +586,6 @@ std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
std::vector<string> result =
HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
if (constrain_layout_) {
result.push_back("constrain_layout=true");
}
if (use_global_device_ids_) {
result.push_back("use_global_device_ids=true");
}
@ -614,11 +614,11 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
HloAllToAllInstruction::HloAllToAllInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id,
const absl::optional<int64>& split_dimension)
: HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
replica_groups, channel_id),
replica_groups, constrain_layout, channel_id),
split_dimension_(split_dimension) {}
std::unique_ptr<HloInstruction>
@ -626,7 +626,8 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloAllToAllInstruction>(
shape, new_operands, replica_groups(), channel_id(), split_dimension());
shape, new_operands, replica_groups(), constrain_layout(), channel_id(),
split_dimension());
}
HloInstructionProto HloAllToAllInstruction::ToProto() const {

View File

@ -313,37 +313,6 @@ class HloCollectiveInstruction : public HloChannelInstruction {
return replica_groups_;
}
protected:
explicit HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<int64>& channel_id);
HloInstructionProto ToProto() const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
std::vector<ReplicaGroup> replica_groups_;
};
class HloAllReduceInstruction : public HloCollectiveInstruction {
public:
explicit HloAllReduceInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id, bool use_global_device_ids);
// Returns true if the AllReduce does no communication, so it's equivalent
// to a mem copy.
bool IsNoop() const;
// Returns true if the layout of the AllReduce is enforced by XLA client (as
// the layout set in the shape). The only reason for the client to set the
// layout is to separately compile computations that communicate with
@ -359,6 +328,38 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
// unconstrained AllReduce instructions (checked by HloVerifier).
bool constrain_layout() const { return constrain_layout_; }
protected:
explicit HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id);
HloInstructionProto ToProto() const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
std::vector<ReplicaGroup> replica_groups_;
bool constrain_layout_;
};
class HloAllReduceInstruction : public HloCollectiveInstruction {
public:
explicit HloAllReduceInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id, bool use_global_device_ids);
// Returns true if the AllReduce does no communication, so it's equivalent
// to a mem copy.
bool IsNoop() const;
// Returns true if the ids in the ReplicaGroup config represent a global id of
// (replica_id * partition_count + partition_id) instead of a replica id.
// This enables more flexible grouping of devices if this all-reduce is both
@ -387,7 +388,6 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
bool constrain_layout_;
bool use_global_device_ids_;
};
@ -395,7 +395,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
public:
explicit HloAllToAllInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id,
const absl::optional<int64>& split_dimension);
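
Because the flag now lives on the common base class, a pass can detect a layout-constrained collective of either kind with a single check (a sketch of the pattern; the layout-assignment change later in this diff adds essentially this helper as IsLayoutConstrainedCollective):

// Sketch: true for any collective (all-reduce or all-to-all) whose result
// layout was fixed by the client.
bool IsLayoutConstrained(const HloInstruction* instruction) {
  const auto* collective = DynCast<HloCollectiveInstruction>(instruction);
  return collective != nullptr && collective->constrain_layout();
}
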

View File

@ -887,6 +887,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
&dimensions};
optional<bool> constrain_layout;
attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
&constrain_layout};
if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
(dimensions && dimensions->size() != 1)) {
return false;
@ -900,7 +903,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
split_dimension = dimensions->at(0);
}
instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
shape, operands, replica_groups, channel_id, split_dimension));
shape, operands, replica_groups,
constrain_layout ? *constrain_layout : false, channel_id,
split_dimension));
break;
}
case HloOpcode::kCollectivePermute: {

View File

@ -944,11 +944,6 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
std::vector<std::vector<int64>> replica_groups) {
const char* kTemplate = R"(
HloModule test
add {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry {
p0 = f32[128]{0} parameter(0)
p1 = f32[128]{0} parameter(1)
@ -994,6 +989,24 @@ TEST_F(HloVerifierTest, AllToAll_WrongNumberOfReplicasInGroup) {
HasSubstr("Replica group has size 1"));
}
TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
const char* const kModuleStr = R"(
HloModule test
ENTRY entry {
p0 = f32[128,4]{0,1} parameter(0)
p1 = f32[128,4]{1,0} parameter(1)
ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
replica_groups={{0,1}}
}
)";
HloModuleConfig config;
config.set_replica_count(2);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr, config));
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
HasSubstr("HLO all-to-all has operands with different shapes"));
}
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
const char* const kModuleStr = R"(
HloModule test

View File

@ -432,10 +432,10 @@ bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
return custom_call != nullptr && custom_call->layout_constrained();
}
bool IsLayoutConstrainedAllReduce(const HloInstruction* instruction) {
const HloAllReduceInstruction* all_reduce =
DynCast<HloAllReduceInstruction>(instruction);
return all_reduce != nullptr && all_reduce->constrain_layout();
bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
const HloCollectiveInstruction* collective =
DynCast<HloCollectiveInstruction>(instruction);
return collective != nullptr && collective->constrain_layout();
}
} // namespace
@ -520,7 +520,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
} else if (IsLayoutConstrainedAllReduce(instruction)) {
} else if (IsLayoutConstrainedCollective(instruction)) {
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(instruction->shape(), instruction));
} else if (instruction->IsCrossModuleAllReduce()) {
@ -1808,7 +1808,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
// Some instructions carry mandatory layouts in their shape.
if (instruction->opcode() != HloOpcode::kInfeed &&
!IsLayoutConstrainedCustomCall(instruction) &&
!IsLayoutConstrainedAllReduce(instruction)) {
!IsLayoutConstrainedCollective(instruction)) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
}
}

View File

@ -1385,5 +1385,42 @@ ENTRY entry_computation {
ExpectLayoutIs(crs->operand(1)->shape(), {1, 0});
}
TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) {
const char* module_str = R"(
HloModule test_module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry_computation {
param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0)
gte0 = f32[16,4] get-tuple-element(param), index=0
gte1 = f32[16,4] get-tuple-element(param), index=1
alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-to-all(gte0, gte1),
replica_groups={{0,1}}, constrain_layout=true
gte2 = f32[16,4] get-tuple-element(alltoall), index=0
gte3 = f32[16,4] get-tuple-element(alltoall), index=1
ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> m,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
ComputationLayout computation_layout(
m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
ChannelLayoutConstraints channel_constraints;
AssignLayouts(m.get(), &computation_layout, &channel_constraints);
const HloInstruction* alltoall = FindInstruction(m.get(), "alltoall");
ExpectTupleLayoutIs(alltoall->shape(), {{1, 0}, {1, 0}});
ExpectLayoutIs(alltoall->operand(0)->shape(), {1, 0});
ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
}
} // namespace
} // namespace xla