[XLA] Introduce asynchronous collective-permute (CollectivePermuteStart and CollectivePermuteDone) HLO opcodes.

PiperOrigin-RevId: 312794240
Change-Id: I0afa0ed1920fb97ac509ff2075559525265a28e2
This commit is contained in:
Jinliang Wei 2020-05-21 21:44:14 -07:00 committed by TensorFlower Gardener
parent 987a095f85
commit 18ab11e146
17 changed files with 263 additions and 37 deletions

View File

@ -120,6 +120,8 @@ class DfsHloVisitorBase {
virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
// Handlers for the two halves of an asynchronous collective-permute:
// invoked for kCollectivePermuteStart and kCollectivePermuteDone instructions
// respectively.
virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0;
virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0;
virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0;
virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0;
virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0;

View File

@ -110,6 +110,12 @@ class DfsHloVisitorWithDefaultBase
// Collective-permute gets no special treatment here; it is forwarded to
// DefaultAction.
Status HandleCollectivePermute(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
// Collective-permute-start gets no special treatment here; it is forwarded to
// DefaultAction.
Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
// Collective-permute-done gets no special treatment here; it is forwarded to
// DefaultAction.
Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
// Replica-id gets no special treatment here; it is forwarded to DefaultAction.
Status HandleReplicaId(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}

View File

@ -736,6 +736,16 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
return Status::OK();
}
// Does nothing for collective-permute-start: the instruction argument is
// unused and the handler always reports success.
Status HloCostAnalysis::HandleCollectivePermuteStart(
const HloInstruction* /*hlo*/) {
return Status::OK();
}
// Does nothing for collective-permute-done: the instruction argument is
// unused and the handler always reports success.
Status HloCostAnalysis::HandleCollectivePermuteDone(
const HloInstruction* /*hlo*/) {
return Status::OK();
}
// Does nothing for partition-id: the instruction argument is unused and the
// handler always reports success.
Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) {
return Status::OK();
}

View File

@ -80,6 +80,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleAllReduce(const HloInstruction* crs) override;
Status HandleAllToAll(const HloInstruction* hlo) override;
Status HandleCollectivePermute(const HloInstruction* hlo) override;
// Handlers for the asynchronous collective-permute start/done opcodes.
Status HandleCollectivePermuteStart(const HloInstruction* hlo) override;
Status HandleCollectivePermuteDone(const HloInstruction* hlo) override;
Status HandleReplicaId(const HloInstruction* hlo) override;
Status HandlePartitionId(const HloInstruction* hlo) override;
Status HandleInfeed(const HloInstruction* infeed) override;

View File

@ -1061,6 +1061,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kPartitionId:

View File

@ -452,7 +452,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
/*channel_id=*/channel_id, split_dimension);
break;
}
case HloOpcode::kCollectivePermute: {
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart: {
std::vector<std::pair<int64, int64>> source_target_pairs(
proto.source_target_pairs_size());
absl::optional<int64> channel_id;
@ -463,8 +464,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
source_target_pairs[i].first = proto.source_target_pairs(i).source();
source_target_pairs[i].second = proto.source_target_pairs(i).target();
}
instruction = CreateCollectivePermute(shape, operands(0),
source_target_pairs, channel_id);
if (opcode == HloOpcode::kCollectivePermute) {
instruction = CreateCollectivePermute(shape, operands(0),
source_target_pairs, channel_id);
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
instruction = CreateCollectivePermuteStart(
shape, operands(0), source_target_pairs, channel_id);
} else {
LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
<< "but got " << HloOpcodeString(opcode);
}
break;
}
case HloOpcode::kReplicaId: {
@ -805,6 +815,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
@ -982,7 +993,18 @@ HloInstruction::CreateCollectivePermute(
const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id) {
return absl::make_unique<HloCollectivePermuteInstruction>(
shape, operand, source_target_pairs, channel_id);
HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
channel_id);
}
// Builds the "start" half of an asynchronous collective-permute. The result is
// an HloCollectivePermuteInstruction tagged with kCollectivePermuteStart
// instead of kCollectivePermute; all other arguments are forwarded unchanged.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermuteStart(
    const Shape& shape, HloInstruction* operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs,
    const absl::optional<int64>& channel_id) {
  const HloOpcode start_opcode = HloOpcode::kCollectivePermuteStart;
  return absl::make_unique<HloCollectivePermuteInstruction>(
      start_opcode, shape, operand, source_target_pairs, channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
@ -1549,6 +1571,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kConvolution:
@ -1575,6 +1598,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
@ -1928,6 +1952,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kComplex:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
@ -2029,6 +2054,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
@ -2888,6 +2914,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleAllToAll(this);
case HloOpcode::kCollectivePermute:
return visitor->HandleCollectivePermute(this);
case HloOpcode::kCollectivePermuteStart:
return visitor->HandleCollectivePermuteStart(this);
case HloOpcode::kCollectivePermuteDone:
return visitor->HandleCollectivePermuteDone(this);
case HloOpcode::kReplicaId:
return visitor->HandleReplicaId(this);
case HloOpcode::kPartitionId:

View File

@ -681,7 +681,7 @@ class HloInstruction {
const absl::optional<int64>& channel_id,
const absl::optional<int64>& split_dimension = absl::nullopt);
// Creates a communication instructions that permutes data cross replicas.
// Creates a communication instruction that permutes data cross replicas.
// Data is sent/received according to the (source_replica_id,
// target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
// target_replica_id in any pair, the output on that replica is a tensor
@ -691,6 +691,13 @@ class HloInstruction {
const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id);
// Creates a communication instruction that initiates the start of
// CollectivePermute.
//
// Takes the same arguments as CreateCollectivePermute; the resulting
// instruction has opcode kCollectivePermuteStart. The HLO verifier requires
// its single user to be a collective-permute-done instruction and its shape to
// be the tuple (operand shape, operand shape, u32[], u32[]).
static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
const Shape& shape, HloInstruction* operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id);
// Creates an instruction that returns a U32 replica ID.
static std::unique_ptr<HloInstruction> CreateReplicaId();

View File

@ -703,10 +703,10 @@ bool HloAllToAllInstruction::IdenticalSlowPath(
}
HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
const Shape& shape, HloInstruction* operand,
HloOpcode opcode, const Shape& shape, HloInstruction* operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id)
: HloChannelInstruction(HloOpcode::kCollectivePermute, shape, channel_id),
: HloChannelInstruction(opcode, shape, channel_id),
source_target_pairs_(source_target_pairs) {
AppendOperand(operand);
}
@ -738,6 +738,9 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
if (opcode() != other.opcode()) {
return false;
}
const auto& casted_other =
static_cast<const HloCollectivePermuteInstruction&>(other);
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
@ -752,7 +755,7 @@ HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloCollectivePermuteInstruction>(
shape, new_operands[0], source_target_pairs(), channel_id());
opcode(), shape, new_operands[0], source_target_pairs(), channel_id());
}
HloReverseInstruction::HloReverseInstruction(const Shape& shape,

View File

@ -463,7 +463,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
class HloCollectivePermuteInstruction : public HloChannelInstruction {
public:
explicit HloCollectivePermuteInstruction(
const Shape& shape, HloInstruction* operand,
HloOpcode opcode, const Shape& shape, HloInstruction* operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id);

View File

@ -63,6 +63,8 @@ namespace xla {
V(kCholesky, "cholesky", 1) \
V(kClamp, "clamp", 3) \
V(kCollectivePermute, "collective-permute", 1) \
V(kCollectivePermuteStart, "collective-permute-start", 1) \
V(kCollectivePermuteDone, "collective-permute-done", 1) \
V(kClz, "count-leading-zeros", 1) \
V(kCompare, "compare", 2) \
V(kComplex, "complex", 2) \

View File

@ -765,6 +765,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
@ -938,7 +939,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
split_dimension));
break;
}
case HloOpcode::kCollectivePermute: {
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart: {
optional<std::vector<std::vector<int64>>> source_targets;
attrs["source_target_pairs"] = {
/*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
@ -957,9 +959,19 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
pairs[i].first = (*source_targets)[i][0];
pairs[i].second = (*source_targets)[i][1];
}
instruction =
builder->AddInstruction(HloInstruction::CreateCollectivePermute(
shape, operands[0], pairs, channel_id));
if (opcode == HloOpcode::kCollectivePermute) {
instruction =
builder->AddInstruction(HloInstruction::CreateCollectivePermute(
shape, operands[0], pairs, channel_id));
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
instruction = builder->AddInstruction(
HloInstruction::CreateCollectivePermuteStart(shape, operands[0],
pairs, channel_id));
} else {
LOG(FATAL) << "Expect opcode to be CollectivePermute or "
"CollectivePermuteStart, but got "
<< HloOpcodeString(opcode);
}
break;
}
case HloOpcode::kReplicaId: {

View File

@ -1553,6 +1553,20 @@ ENTRY CollectivePermute {
ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
}
)",
/*replica_count=*/4
},
// collective-permute-start and -done
{
"CollectivePermuteStartAndDone",
R"(HloModule CollectivePermuteStartAndDone
ENTRY CollectivePermuteStartAndDone {
input = f32[128,32]{0,1} parameter(0)
collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}}
ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1)
}
)",
/*replica_count=*/4
},

View File

@ -74,7 +74,6 @@ Status CheckParameterCount(const HloInstruction* calling_instruction,
}
return Status::OK();
}
} // namespace
Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
@ -332,7 +331,9 @@ Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
}
Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
namespace {
Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) {
// A source or target cannot appear twice in the collective-permute's
// source-target pairs.
absl::flat_hash_set<int64> seen_sources;
@ -351,10 +352,30 @@ Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
p.second, hlo->ToString());
}
}
return Status::OK();
}
} // namespace
// Verifies a collective-permute: its source/target pairs must pass
// CheckDuplicatedSourceOrTarget, and its shape must match the shape inferred
// from its single operand.
Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
  TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));
  const Shape& operand_shape = hlo->operand(0)->shape();
  return CheckShape(
      hlo, ShapeInference::InferCollectivePermuteShape(operand_shape));
}
// Verifies a collective-permute-start: its source/target pairs must pass
// CheckDuplicatedSourceOrTarget, and its shape must be the 4-tuple
// (operand shape, operand shape, u32[], u32[]).
Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) {
  TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));
  const Shape& operand_shape = hlo->operand(0)->shape();
  const Shape u32_scalar = ShapeUtil::MakeShape(U32, {});
  const Shape expected_shape = ShapeUtil::MakeTupleShape(
      {operand_shape, operand_shape, u32_scalar, u32_scalar});
  return CheckShape(hlo, expected_shape);
}
// Verifies a collective-permute-done: its shape must equal element 0 of its
// operand's tuple shape.
Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) {
  const Shape& start_shape = hlo->operand(0)->shape();
  const Shape& expected_shape = ShapeUtil::GetTupleElementShape(start_shape, 0);
  return CheckShape(hlo, expected_shape);
}
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
reduce_precision->operand(0)->shape(),
@ -1375,32 +1396,60 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1,
return Status::OK();
}
// Checks CopyStart and CopyDone nodes.
Status VerifyAsynchronousCopies(const HloModule& module) {
// Verifies that `instruction` has exactly one user and that the user's opcode
// is `expected_user`. Returns OK when both hold; otherwise the failing
// TF_RET_CHECK yields an error whose message names the opcodes involved.
Status VerifySingleUser(const HloInstruction* instruction,
HloOpcode expected_user) {
// First check: exactly one consumer.
TF_RET_CHECK(instruction->users().size() == 1)
<< "The " << HloOpcodeString(instruction->opcode())
<< " instruction requires one consumer, found "
<< instruction->users().size();
// Second check: that sole consumer has the expected opcode.
const HloInstruction* user = instruction->users().front();
TF_RET_CHECK(user->opcode() == expected_user)
<< "The consumer of a " << HloOpcodeString(instruction->opcode())
<< " instruction needs to be " << HloOpcodeString(expected_user)
<< ", found " << HloOpcodeString(user->opcode());
return Status::OK();
}
// Verifies that `instruction` has exactly one operand and that the operand's
// opcode is `expected_operand`. Returns OK when both hold; otherwise the
// failing TF_RET_CHECK yields an error whose message names the opcodes
// involved.
Status VerifySingleOperand(const HloInstruction* instruction,
                           HloOpcode expected_operand) {
  // This check is about what the instruction consumes, so the diagnostic
  // reports the operand count (not the user count).
  TF_RET_CHECK(instruction->operands().size() == 1)
      << "The " << HloOpcodeString(instruction->opcode())
      << " instruction requires one operand, found "
      << instruction->operands().size();
  // The sole operand must have the expected opcode.
  const HloInstruction* operand = instruction->operand(0);
  TF_RET_CHECK(operand->opcode() == expected_operand)
      << "The operand of a " << HloOpcodeString(instruction->opcode())
      << " instruction needs to be " << HloOpcodeString(expected_operand)
      << ", found " << HloOpcodeString(operand->opcode());
  return Status::OK();
}
// Checks asynchronous instruction pairs.
Status VerifyAsynchronousInstructionPairs(const HloModule& module) {
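// Every *Start instruction (CopyStart, CollectivePermuteStart) must have
// exactly one user, the matching *Done; every *Done instruction must have the
// matching *Start as its only operand.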
for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
switch (instruction->opcode()) {
case HloOpcode::kCopyStart: {
TF_RET_CHECK(instruction->users().size() == 1)
<< "CopyStart instruction requires one consumer, found "
<< instruction->users().size();
const HloInstruction* copy_done = instruction->users().front();
TF_RET_CHECK(copy_done->opcode() == HloOpcode::kCopyDone)
<< "The consumer of a CopyStart instruction needs to be "
"CopyDone, found "
<< HloOpcodeString(copy_done->opcode());
TF_RETURN_IF_ERROR(
VerifySingleUser(instruction, HloOpcode::kCopyDone));
break;
}
case HloOpcode::kCopyDone: {
TF_RET_CHECK(instruction->operands().size() == 1)
<< "CopyDone instruction requires one operand, found "
<< instruction->operands().size();
const HloInstruction* copy_start = instruction->operand(0);
TF_RET_CHECK(copy_start->opcode() == HloOpcode::kCopyStart)
<< "The operand of a CopyDone instruction needs to be CopyStart, "
"found "
<< HloOpcodeString(copy_start->opcode());
TF_RETURN_IF_ERROR(
VerifySingleOperand(instruction, HloOpcode::kCopyStart));
break;
}
case HloOpcode::kCollectivePermuteStart: {
TF_RETURN_IF_ERROR(
VerifySingleUser(instruction, HloOpcode::kCollectivePermuteDone));
break;
}
case HloOpcode::kCollectivePermuteDone: {
TF_RETURN_IF_ERROR(VerifySingleOperand(
instruction, HloOpcode::kCollectivePermuteStart));
break;
}
default:
@ -1815,7 +1864,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
}
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifyAsynchronousCopies(*module));
TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module));
TF_RETURN_IF_ERROR(VerifyChannels(*module));
std::unique_ptr<ShapeVerifier> shape_verifier =

View File

@ -60,6 +60,8 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleAllReduce(HloInstruction* crs) override;
Status HandleAllToAll(HloInstruction* hlo) override;
Status HandleCollectivePermute(HloInstruction* hlo) override;
// Shape checks for the asynchronous collective-permute start/done opcodes.
Status HandleCollectivePermuteStart(HloInstruction* hlo) override;
Status HandleCollectivePermuteDone(HloInstruction* hlo) override;
Status HandlePartitionId(HloInstruction* hlo) override;
Status HandleReplicaId(HloInstruction* hlo) override;
Status HandleReducePrecision(HloInstruction* reduce_precision) override;

View File

@ -710,7 +710,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
ASSERT_FALSE(status.ok());
EXPECT_THAT(
status.error_message(),
HasSubstr("CopyStart instruction requires one consumer, found 2"));
HasSubstr("copy-start instruction requires one consumer, found 2"));
}
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
@ -730,8 +730,8 @@ TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("The operand of a CopyDone instruction needs to be "
"CopyStart, found tuple"));
HasSubstr("The operand of a copy-done instruction needs to be "
"copy-start, found tuple"));
}
TEST_F(HloVerifierTest, IotaNonArrayResult) {
@ -1134,5 +1134,86 @@ TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
HasSubstr("used for different types of channel instructions"));
}
// A well-formed start/done pair -- the start produces the tuple
// (operand, operand, u32[], u32[]) and the done is its only user -- must pass
// the verifier.
TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) {
const char* const kModuleStr = R"(
HloModule Module
ENTRY CollectivePermuteStartAndDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));
auto status = verifier().Run(module.get()).status();
ASSERT_TRUE(status.ok());
}
// A collective-permute-start whose result is a plain array rather than the
// expected 4-tuple must be rejected with a shape-mismatch error.
TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) {
const char* const kModuleStr = R"(
HloModule Module
ENTRY CollectivePermuteStartAndDoneWrongType {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Expected instruction to have shape equal to "
"(f32[2,3], f32[2,3], u32[], u32[])"));
}
// A collective-permute-start consumed by two collective-permute-done
// instructions must be rejected: a start requires exactly one consumer.
TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) {
const char* const kModuleStr = R"(
HloModule Module
ENTRY CollectivePermuteStartAndMultipleDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(
status.error_message(),
HasSubstr("collective-permute-start instruction requires one consumer, "
"found 2"));
}
// A collective-permute-done whose operand is an ordinary tuple (not a
// collective-permute-start) must be rejected. The tuple is deliberately
// well-formed -- four operands matching its four-element shape -- so that the
// only thing wrong with the module is the property under test.
TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
  const char* const kModuleStr = R"(
  HloModule Module
  ENTRY CollectivePermuteDoneNoCollectivePermuteStart {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    p1 = f32[2,3]{1,0:S(1)} parameter(1)
    p2 = u32[] parameter(2)
    p3 = u32[] parameter(3)
    tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("The operand of a collective-permute-done instruction "
                        "needs to be collective-permute-start, found tuple"));
}
} // namespace
} // namespace xla

View File

@ -149,6 +149,8 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kCustomCall:
case HloOpcode::kDomain:
case HloOpcode::kDot:

View File

@ -2234,6 +2234,8 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
case HloOpcode::kCall:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kConstant:
case HloOpcode::kConvolution:
case HloOpcode::kCopy: