Add optional layout constraint to AllToAll
PiperOrigin-RevId: 308740413
Change-Id: I48773db125a4c542d1c8c46c71100bda4a2b1108

parent c1d09d25d2
commit e2ccb9a5a5
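
A minimal sketch of how a client might use the new parameter, assuming the free functions and defaults introduced in this diff (the builder setup, shape, and layout values are illustrative, not from the commit):

    XlaBuilder b("a2a");
    XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {128, 4}), "x");
    // Omitting `layout` keeps the previous, unconstrained behavior.
    XlaOp unconstrained = AllToAll(x, /*split_dimension=*/0,
                                   /*concat_dimension=*/0, /*split_count=*/2);
    // Passing a layout forces it on the instruction, so a group of AllToAll
    // ops compiled separately can be guaranteed to agree on it.
    XlaOp constrained = AllToAll(x, /*split_dimension=*/0,
                                 /*concat_dimension=*/0, /*split_count=*/2,
                                 /*replica_groups=*/{},
                                 LayoutUtil::MakeLayout({1, 0}));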
@@ -2307,7 +2307,8 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
 
 XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
                            int64 concat_dimension, int64 split_count,
-                           const std::vector<ReplicaGroup>& replica_groups) {
+                           const std::vector<ReplicaGroup>& replica_groups,
+                           const absl::optional<Layout>& layout) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@@ -2342,7 +2343,21 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
         [](const Shape& shape) { return &shape; });
     TF_ASSIGN_OR_RETURN(
         Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
+
+    if (layout) {
+      TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape));
+      for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+        if (layout->minor_to_major().size() != shape.tuple_shapes(i).rank()) {
+          return InvalidArgument(
+              "Provided layout must be compatible with the operand shape: %s "
+              "vs %s",
+              layout->ToString(), operand_shape->ToString());
+        }
+        *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
+      }
+    }
     *instr.mutable_shape() = shape.ToProto();
 
     for (const ReplicaGroup& group : replica_groups) {
       *instr.add_replica_groups() = group;
     }
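
The validation above accepts a layout only if its minor_to_major has one entry per dimension of each tuple element; a hedged illustration of the rule (the shapes are made up):

    // For a rank-2 element such as f32[128,4]:
    Layout ok = LayoutUtil::MakeLayout({1, 0});      // 2 entries, rank 2: accepted
    Layout bad = LayoutUtil::MakeLayout({2, 1, 0});  // 3 entries, rank 2: InvalidArgument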
@@ -3466,9 +3481,10 @@ XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
 
 XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
                int64 concat_dimension, int64 split_count,
-               const std::vector<ReplicaGroup>& replica_groups) {
+               const std::vector<ReplicaGroup>& replica_groups,
+               const absl::optional<Layout>& layout) {
   return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
-                                     split_count, replica_groups);
+                                     split_count, replica_groups, layout);
 }
 
 XlaOp CollectivePermute(
@@ -557,7 +557,8 @@ class XlaBuilder {
 
   XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
                  int64 split_count,
-                 const std::vector<ReplicaGroup>& replica_groups);
+                 const std::vector<ReplicaGroup>& replica_groups,
+                 const absl::optional<Layout>& layout = absl::nullopt);
 
   XlaOp CollectivePermute(
       XlaOp operand,
@@ -993,7 +994,8 @@ class XlaBuilder {
                           const absl::optional<Shape>& shape_with_layout);
   friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
                         int64 concat_dimension, int64 split_count,
-                        const std::vector<ReplicaGroup>& replica_groups);
+                        const std::vector<ReplicaGroup>& replica_groups,
+                        const absl::optional<Layout>& layout);
   friend XlaOp CollectivePermute(
       XlaOp operand,
       const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1789,9 +1791,13 @@ XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                 const absl::optional<Shape>& shape_with_layout = absl::nullopt);
 
 // Enqueues an operation that do an Alltoall of the operand cross cores.
+// An optional `layout` can be specified to force the layout of the instruction.
+// This is used to guarantee the same layout for a group of AllToAll ops
+// compiled separately.
 XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
                int64 split_count,
-               const std::vector<ReplicaGroup>& replica_groups = {});
+               const std::vector<ReplicaGroup>& replica_groups = {},
+               const absl::optional<Layout>& layout = absl::nullopt);
 
 // Enqueues an collective operation that sends and receives data cross replicas.
 //
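
Because both trailing parameters are defaulted (`replica_groups = {}`, `layout = absl::nullopt`), existing call sites keep compiling unchanged; a sketch, assuming an XlaOp `x` from earlier in the computation:

    // Pre-existing call site, still valid: the result layout stays unconstrained.
    XlaOp y = AllToAll(x, /*split_dimension=*/0, /*concat_dimension=*/1,
                       /*split_count=*/2);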
@@ -330,7 +330,8 @@ void BuildOpsSubmodule(py::module* m) {
           py::arg("shape_with_layout") = absl::nullopt);
   ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
           py::arg("concat_dimension"), py::arg("split_count"),
-          py::arg("replica_groups") = py::list());
+          py::arg("replica_groups") = py::list(),
+          py::arg("layout") = absl::nullopt);
   ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
           py::arg("source_target_pairs"));
   ops.def("CreateToken", &CreateToken, py::arg("builder"));
@@ -289,7 +289,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
   replica_groups[0].add_replica_ids(1);
   HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
       ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), {a, a},
-      replica_groups, absl::nullopt));
+      replica_groups, /*constrain_layout=*/false, absl::nullopt));
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(Normalize(module.get()));
@@ -318,7 +318,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) {
   replica_groups[0].add_replica_ids(1);
   HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
       ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), {a, a},
-      replica_groups, absl::nullopt));
+      replica_groups, /*constrain_layout=*/false, absl::nullopt));
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(Normalize(module.get()));
@@ -54,8 +54,9 @@ StatusOr<bool> HloDCE::RunOnComputation(
          (remove_cross_partition_collective_ops &&
           ((instruction->opcode() == HloOpcode::kAllReduce &&
             !Cast<HloAllReduceInstruction>(instruction)->constrain_layout()) ||
-           instruction->opcode() == HloOpcode::kCollectivePermute ||
-           instruction->opcode() == HloOpcode::kAllToAll)))) {
+           (instruction->opcode() == HloOpcode::kAllToAll &&
+            !Cast<HloAllToAllInstruction>(instruction)->constrain_layout()) ||
+           instruction->opcode() == HloOpcode::kCollectivePermute)))) {
       dead_roots.push_back(instruction);
     }
   }
@@ -430,6 +430,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
           /*replica_groups=*/
           std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                     proto.replica_groups().end()),
+          /*constrain_layout=*/proto.constrain_layout(),
           /*channel_id=*/channel_id, split_dimension);
       break;
     }
@@ -939,11 +940,12 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
-    const std::vector<ReplicaGroup>& replica_groups,
+    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
     const absl::optional<int64>& channel_id,
    const absl::optional<int64>& split_dimension) {
   return absl::make_unique<HloAllToAllInstruction>(
-      shape, operands, replica_groups, channel_id, split_dimension);
+      shape, operands, replica_groups, constrain_layout, channel_id,
+      split_dimension);
 }
 
 /* static */ std::unique_ptr<HloInstruction>
@@ -1375,6 +1377,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
     case HloOpcode::kAllReduce:
       return channel_id().has_value() ||
              Cast<HloAllReduceInstruction>(this)->constrain_layout();
+    case HloOpcode::kAllToAll:
+      return Cast<HloAllToAllInstruction>(this)->constrain_layout();
     case HloOpcode::kCustomCall:
       return Cast<HloCustomCallInstruction>(this)
           ->custom_call_has_side_effect();
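
Together with the HloDCE change above, this makes a layout-constrained AllToAll count as side-effecting, so it survives dead-code elimination even when its result is unused. A sketch using the CreateAllToAll signature from this diff (`tuple_shape`, `op0`, `op1`, and `groups` are assumed to be set up elsewhere):

    auto a2a = HloInstruction::CreateAllToAll(
        tuple_shape, {op0, op1}, groups, /*constrain_layout=*/true,
        /*channel_id=*/absl::nullopt);
    // a2a->HasSideEffect() now returns true, so passes like HloDCE keep it.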
@@ -667,7 +667,7 @@ class HloInstruction {
   // It is used to implement the higher-level instruction in XlaBuilder.
   static std::unique_ptr<HloInstruction> CreateAllToAll(
       const Shape& shape, absl::Span<HloInstruction* const> operands,
-      const std::vector<ReplicaGroup>& replica_groups,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
       const absl::optional<int64>& channel_id,
       const absl::optional<int64>& split_dimension = absl::nullopt);
 
@@ -513,10 +513,11 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
 HloCollectiveInstruction::HloCollectiveInstruction(
     HloOpcode opcode, const Shape& shape,
     absl::Span<HloInstruction* const> operands,
-    const std::vector<ReplicaGroup>& replica_groups,
+    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
     const absl::optional<int64>& channel_id)
     : HloChannelInstruction(opcode, shape, channel_id),
-      replica_groups_(replica_groups) {
+      replica_groups_(replica_groups),
+      constrain_layout_(constrain_layout) {
   for (auto operand : operands) {
     AppendOperand(operand);
   }
@@ -526,6 +527,7 @@ HloInstructionProto HloCollectiveInstruction::ToProto() const {
   HloInstructionProto proto = HloChannelInstruction::ToProto();
   *proto.mutable_replica_groups() = {replica_groups_.begin(),
                                      replica_groups_.end()};
+  proto.set_constrain_layout(constrain_layout_);
   return proto;
 }
 
@@ -535,6 +537,9 @@ std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
       HloChannelInstruction::ExtraAttributesToStringImpl(options);
   result.push_back(
       StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
+  if (constrain_layout_) {
+    result.push_back("constrain_layout=true");
+  }
   return result;
 }
 
@@ -557,8 +562,7 @@ HloAllReduceInstruction::HloAllReduceInstruction(
     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
     const absl::optional<int64>& channel_id, bool use_global_device_ids)
     : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
-                               replica_groups, channel_id),
-      constrain_layout_(constrain_layout),
+                               replica_groups, constrain_layout, channel_id),
       use_global_device_ids_(use_global_device_ids) {
   AppendComputation(reduce_computation);
 }
@@ -574,7 +578,6 @@ bool HloAllReduceInstruction::IsNoop() const {
 
 HloInstructionProto HloAllReduceInstruction::ToProto() const {
   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
-  proto.set_constrain_layout(constrain_layout_);
   proto.set_use_global_device_ids(use_global_device_ids_);
   return proto;
 }
@@ -583,9 +586,6 @@ std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
     const HloPrintOptions& options) const {
   std::vector<string> result =
       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
-  if (constrain_layout_) {
-    result.push_back("constrain_layout=true");
-  }
   if (use_global_device_ids_) {
     result.push_back("use_global_device_ids=true");
   }
@@ -614,11 +614,11 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
 
 HloAllToAllInstruction::HloAllToAllInstruction(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
-    const std::vector<ReplicaGroup>& replica_groups,
+    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
     const absl::optional<int64>& channel_id,
     const absl::optional<int64>& split_dimension)
     : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
-                               replica_groups, channel_id),
+                               replica_groups, constrain_layout, channel_id),
       split_dimension_(split_dimension) {}
 
 std::unique_ptr<HloInstruction>
@@ -626,7 +626,8 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* /*context*/) const {
   return absl::make_unique<HloAllToAllInstruction>(
-      shape, new_operands, replica_groups(), channel_id(), split_dimension());
+      shape, new_operands, replica_groups(), constrain_layout(), channel_id(),
+      split_dimension());
 }
 
 HloInstructionProto HloAllToAllInstruction::ToProto() const {
@@ -313,37 +313,6 @@ class HloCollectiveInstruction : public HloChannelInstruction {
     return replica_groups_;
   }
 
- protected:
-  explicit HloCollectiveInstruction(
-      HloOpcode opcode, const Shape& shape,
-      absl::Span<HloInstruction* const> operands,
-      const std::vector<ReplicaGroup>& replica_groups,
-      const absl::optional<int64>& channel_id);
-
-  HloInstructionProto ToProto() const override;
-
-  std::vector<string> ExtraAttributesToStringImpl(
-      const HloPrintOptions& options) const override;
-  bool IdenticalSlowPath(
-      const HloInstruction& other,
-      const std::function<bool(const HloComputation*, const HloComputation*)>&
-          eq_computations) const override;
-
-  std::vector<ReplicaGroup> replica_groups_;
-};
-
-class HloAllReduceInstruction : public HloCollectiveInstruction {
- public:
-  explicit HloAllReduceInstruction(
-      const Shape& shape, absl::Span<HloInstruction* const> operands,
-      HloComputation* reduce_computation,
-      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
-      const absl::optional<int64>& channel_id, bool use_global_device_ids);
-
-  // Returns true if the AllReduce does no communication, so it's equivalent
-  // to a mem copy.
-  bool IsNoop() const;
-
   // Returns true if the layout of the AllReduce is enforced by XLA client (as
   // the layout set in the shape). The only reason for the client to set the
   // layout is to separately compile computations that communicate with
@@ -359,6 +328,38 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
   // unconstrained AllReduce instructions (checked by HloVerifier).
   bool constrain_layout() const { return constrain_layout_; }
 
+ protected:
+  explicit HloCollectiveInstruction(
+      HloOpcode opcode, const Shape& shape,
+      absl::Span<HloInstruction* const> operands,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
+      const absl::optional<int64>& channel_id);
+
+  HloInstructionProto ToProto() const override;
+
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+
+  std::vector<ReplicaGroup> replica_groups_;
+  bool constrain_layout_;
+};
+
+class HloAllReduceInstruction : public HloCollectiveInstruction {
+ public:
+  explicit HloAllReduceInstruction(
+      const Shape& shape, absl::Span<HloInstruction* const> operands,
+      HloComputation* reduce_computation,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
+      const absl::optional<int64>& channel_id, bool use_global_device_ids);
+
+  // Returns true if the AllReduce does no communication, so it's equivalent
+  // to a mem copy.
+  bool IsNoop() const;
+
   // Returns true if the ids in the ReplicaGroup config represent a global id of
   // (replica_id * partition_count + partition_id) instead of a replica id.
   // This enables more flexible grouping of devices if this all-reduce is both
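
The mechanical effect of the two hunks above is a hoist: `constrain_layout_` and its accessor move from HloAllReduceInstruction to the HloCollectiveInstruction base, so any collective can be queried uniformly, which is exactly what the layout-assignment change below relies on. A short sketch of the pattern (assuming an `instr` of type HloInstruction*):

    // DynCast yields nullptr for non-collective instructions.
    if (auto* collective = DynCast<HloCollectiveInstruction>(instr)) {
      if (collective->constrain_layout()) {
        // The client fixed this op's layout; passes must not change it.
      }
    }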
@@ -387,7 +388,6 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
       HloCloneContext* context) const override;
 
-  bool constrain_layout_;
   bool use_global_device_ids_;
 };
 
@@ -395,7 +395,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
  public:
   explicit HloAllToAllInstruction(
       const Shape& shape, absl::Span<HloInstruction* const> operands,
-      const std::vector<ReplicaGroup>& replica_groups,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id,
      const absl::optional<int64>& split_dimension);
 
@@ -887,6 +887,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
       optional<std::vector<int64>> dimensions;
       attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
                              &dimensions};
+      optional<bool> constrain_layout;
+      attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
+                                   &constrain_layout};
       if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
           (dimensions && dimensions->size() != 1)) {
         return false;
@@ -900,7 +903,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
         split_dimension = dimensions->at(0);
       }
       instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
-          shape, operands, replica_groups, channel_id, split_dimension));
+          shape, operands, replica_groups,
+          constrain_layout ? *constrain_layout : false, channel_id,
+          split_dimension));
       break;
     }
     case HloOpcode::kCollectivePermute: {
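
With the new attribute, the HLO parser accepts `constrain_layout=true` on all-to-all; a sketch of module text that now parses (modeled on the verifier test below, parse-only):

    const char* kHlo = R"(
    HloModule m
    ENTRY e {
      p0 = f32[128,4]{0,1} parameter(0)
      p1 = f32[128,4]{1,0} parameter(1)
      ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
          replica_groups={{0,1}}, constrain_layout=true
    }
    )";
    // ParseAndReturnUnverifiedModule(kHlo) preserves the attribute, and the
    // module's ToString() prints "constrain_layout=true" back out.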
@@ -944,11 +944,6 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
     std::vector<std::vector<int64>> replica_groups) {
   const char* kTemplate = R"(
   HloModule test
-  add {
-    x = f32[] parameter(0)
-    y = f32[] parameter(1)
-    ROOT add = f32[] add(x, y)
-  }
   ENTRY entry {
     p0 = f32[128]{0} parameter(0)
     p1 = f32[128]{0} parameter(1)
@@ -994,6 +989,24 @@ TEST_F(HloVerifierTest, AllToAll_WrongNumberOfReplicasInGroup) {
               HasSubstr("Replica group has size 1"));
 }
 
+TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY entry {
+    p0 = f32[128,4]{0,1} parameter(0)
+    p1 = f32[128,4]{1,0} parameter(1)
+    ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
+        replica_groups={{0,1}}
+  }
+  )";
+  HloModuleConfig config;
+  config.set_replica_count(2);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(kModuleStr, config));
+  EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+              HasSubstr("HLO all-to-all has operands with different shapes"));
+}
+
 TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
   const char* const kModuleStr = R"(
   HloModule test
@@ -432,10 +432,10 @@ bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
   return custom_call != nullptr && custom_call->layout_constrained();
 }
 
-bool IsLayoutConstrainedAllReduce(const HloInstruction* instruction) {
-  const HloAllReduceInstruction* all_reduce =
-      DynCast<HloAllReduceInstruction>(instruction);
-  return all_reduce != nullptr && all_reduce->constrain_layout();
+bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
+  const HloCollectiveInstruction* collective =
+      DynCast<HloCollectiveInstruction>(instruction);
+  return collective != nullptr && collective->constrain_layout();
 }
 
 }  // namespace
@@ -520,7 +520,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
         TF_RETURN_IF_ERROR(
             constraints->SetBufferLayout(new_shape.layout(), *buffer));
       }
-    } else if (IsLayoutConstrainedAllReduce(instruction)) {
+    } else if (IsLayoutConstrainedCollective(instruction)) {
       TF_RETURN_IF_ERROR(
           constraints->SetInstructionLayout(instruction->shape(), instruction));
     } else if (instruction->IsCrossModuleAllReduce()) {
@@ -1808,7 +1808,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
     // Some instructions carry mandatory layouts in their shape.
     if (instruction->opcode() != HloOpcode::kInfeed &&
         !IsLayoutConstrainedCustomCall(instruction) &&
-        !IsLayoutConstrainedAllReduce(instruction)) {
+        !IsLayoutConstrainedCollective(instruction)) {
       LayoutUtil::ClearLayout(instruction->mutable_shape());
     }
   }
@@ -1385,5 +1385,42 @@ ENTRY entry_computation {
   ExpectLayoutIs(crs->operand(1)->shape(), {1, 0});
 }
 
+TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) {
+  const char* module_str = R"(
+HloModule test_module
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY entry_computation {
+  param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0)
+  gte0 = f32[16,4] get-tuple-element(param), index=0
+  gte1 = f32[16,4] get-tuple-element(param), index=1
+  alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-reduce(gte0, gte1),
+      replica_groups={{0,1}}, constrain_layout=true, to_apply=add
+  gte2 = f32[16,4] get-tuple-element(alltoall), index=0
+  gte3 = f32[16,4] get-tuple-element(alltoall), index=1
+  ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> m,
+      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
+
+  ChannelLayoutConstraints channel_constraints;
+  AssignLayouts(m.get(), &computation_layout, &channel_constraints);
+
+  const HloInstruction* alltoall = FindInstruction(m.get(), "alltoall");
+  ExpectTupleLayoutIs(alltoall->shape(), {{1, 0}, {1, 0}});
+  ExpectLayoutIs(alltoall->operand(0)->shape(), {1, 0});
+  ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
+}
+
 }  // namespace
 }  // namespace xla