[XLA] Introduce asynchronous collective-permute (CollectivePermuteStart and CollectivePermuteDone) HLO opcodes.

PiperOrigin-RevId: 312794240
Change-Id: I0afa0ed1920fb97ac509ff2075559525265a28e2
This commit is contained in:
Jinliang Wei 2020-05-21 21:44:14 -07:00 committed by TensorFlower Gardener
parent 987a095f85
commit 18ab11e146
17 changed files with 263 additions and 37 deletions

View File

@ -120,6 +120,8 @@ class DfsHloVisitorBase {
virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
// Handlers for the asynchronous collective-permute opcodes
// (collective-permute-start and collective-permute-done). Both are pure
// virtual, so every concrete visitor has to provide an implementation.
virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0;
virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0;
virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0;
virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0; virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0;
virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0;

View File

@ -110,6 +110,12 @@ class DfsHloVisitorWithDefaultBase
Status HandleCollectivePermute(HloInstructionPtr hlo) override { Status HandleCollectivePermute(HloInstructionPtr hlo) override {
return DefaultAction(hlo); return DefaultAction(hlo);
} }
// collective-permute-start gets no specialized treatment in this visitor; the
// instruction is simply forwarded to DefaultAction.
Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
// collective-permute-done gets no specialized treatment in this visitor; the
// instruction is simply forwarded to DefaultAction.
Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
Status HandleReplicaId(HloInstructionPtr hlo) override { Status HandleReplicaId(HloInstructionPtr hlo) override {
return DefaultAction(hlo); return DefaultAction(hlo);
} }

View File

@ -736,6 +736,16 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
return Status::OK(); return Status::OK();
} }
// Handler for collective-permute-start. The instruction is not inspected (the
// parameter is deliberately unnamed); the handler only reports success.
Status HloCostAnalysis::HandleCollectivePermuteStart(
const HloInstruction* /*hlo*/) {
return Status::OK();
}
// Handler for collective-permute-done. The instruction is not inspected (the
// parameter is deliberately unnamed); the handler only reports success.
Status HloCostAnalysis::HandleCollectivePermuteDone(
const HloInstruction* /*hlo*/) {
return Status::OK();
}
Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) { Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) {
return Status::OK(); return Status::OK();
} }

View File

@ -80,6 +80,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllReduce(const HloInstruction* crs) override;
Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleAllToAll(const HloInstruction* hlo) override;
Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override;
// Visitor overrides for the collective-permute-start and
// collective-permute-done opcodes.
Status HandleCollectivePermuteStart(const HloInstruction* hlo) override;
Status HandleCollectivePermuteDone(const HloInstruction* hlo) override;
Status HandleReplicaId(const HloInstruction* hlo) override; Status HandleReplicaId(const HloInstruction* hlo) override;
Status HandlePartitionId(const HloInstruction* hlo) override; Status HandlePartitionId(const HloInstruction* hlo) override;
Status HandleInfeed(const HloInstruction* infeed) override; Status HandleInfeed(const HloInstruction* infeed) override;

View File

@ -1061,6 +1061,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kAllReduce: case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll: case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kInfeed: case HloOpcode::kInfeed:
case HloOpcode::kOutfeed: case HloOpcode::kOutfeed:
case HloOpcode::kPartitionId: case HloOpcode::kPartitionId:

View File

@ -452,7 +452,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
/*channel_id=*/channel_id, split_dimension); /*channel_id=*/channel_id, split_dimension);
break; break;
} }
case HloOpcode::kCollectivePermute: { case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart: {
std::vector<std::pair<int64, int64>> source_target_pairs( std::vector<std::pair<int64, int64>> source_target_pairs(
proto.source_target_pairs_size()); proto.source_target_pairs_size());
absl::optional<int64> channel_id; absl::optional<int64> channel_id;
@ -463,8 +464,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
source_target_pairs[i].first = proto.source_target_pairs(i).source(); source_target_pairs[i].first = proto.source_target_pairs(i).source();
source_target_pairs[i].second = proto.source_target_pairs(i).target(); source_target_pairs[i].second = proto.source_target_pairs(i).target();
} }
if (opcode == HloOpcode::kCollectivePermute) {
instruction = CreateCollectivePermute(shape, operands(0), instruction = CreateCollectivePermute(shape, operands(0),
source_target_pairs, channel_id); source_target_pairs, channel_id);
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
instruction = CreateCollectivePermuteStart(
shape, operands(0), source_target_pairs, channel_id);
} else {
LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
<< "but got " << HloOpcodeString(opcode);
}
break; break;
} }
case HloOpcode::kReplicaId: { case HloOpcode::kReplicaId: {
@ -805,6 +815,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
case HloOpcode::kRoundNearestAfz: case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast: case HloOpcode::kBitcast:
case HloOpcode::kCeil: case HloOpcode::kCeil:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy: case HloOpcode::kCopy:
case HloOpcode::kCopyStart: case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone: case HloOpcode::kCopyDone:
@ -982,7 +993,18 @@ HloInstruction::CreateCollectivePermute(
const std::vector<std::pair<int64, int64>>& source_target_pairs, const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id) { const absl::optional<int64>& channel_id) {
return absl::make_unique<HloCollectivePermuteInstruction>( return absl::make_unique<HloCollectivePermuteInstruction>(
shape, operand, source_target_pairs, channel_id); HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
channel_id);
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermuteStart(
const Shape& shape, HloInstruction* operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id) {
return absl::make_unique<HloCollectivePermuteInstruction>(
HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs,
channel_id);
} }
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() { /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
@ -1549,6 +1571,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kAllReduce: case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll: case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kInfeed: case HloOpcode::kInfeed:
case HloOpcode::kOutfeed: case HloOpcode::kOutfeed:
case HloOpcode::kConvolution: case HloOpcode::kConvolution:
@ -1575,6 +1598,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kBitcast: case HloOpcode::kBitcast:
case HloOpcode::kCeil: case HloOpcode::kCeil:
case HloOpcode::kClz: case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy: case HloOpcode::kCopy:
case HloOpcode::kCopyStart: case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone: case HloOpcode::kCopyDone:
@ -1928,6 +1952,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kCeil: case HloOpcode::kCeil:
case HloOpcode::kClamp: case HloOpcode::kClamp:
case HloOpcode::kClz: case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kComplex: case HloOpcode::kComplex:
case HloOpcode::kConvert: case HloOpcode::kConvert:
case HloOpcode::kCopy: case HloOpcode::kCopy:
@ -2029,6 +2054,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kAllReduce: case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll: case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kConvolution: case HloOpcode::kConvolution:
case HloOpcode::kCustomCall: case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow: case HloOpcode::kReduceWindow:
@ -2888,6 +2914,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleAllToAll(this); return visitor->HandleAllToAll(this);
case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermute:
return visitor->HandleCollectivePermute(this); return visitor->HandleCollectivePermute(this);
case HloOpcode::kCollectivePermuteStart:
return visitor->HandleCollectivePermuteStart(this);
case HloOpcode::kCollectivePermuteDone:
return visitor->HandleCollectivePermuteDone(this);
case HloOpcode::kReplicaId: case HloOpcode::kReplicaId:
return visitor->HandleReplicaId(this); return visitor->HandleReplicaId(this);
case HloOpcode::kPartitionId: case HloOpcode::kPartitionId:

View File

@ -681,7 +681,7 @@ class HloInstruction {
const absl::optional<int64>& channel_id, const absl::optional<int64>& channel_id,
const absl::optional<int64>& split_dimension = absl::nullopt); const absl::optional<int64>& split_dimension = absl::nullopt);
// Creates a communication instructions that permutes data cross replicas. // Creates a communication instruction that permutes data cross replicas.
// Data is sent/received according to the (source_replica_id, // Data is sent/received according to the (source_replica_id,
// target_replica_id) pairs in `source_target_pairs`. If a replica id is not a // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
// target_replica_id in any pair, the output on that replica is a tensor // target_replica_id in any pair, the output on that replica is a tensor
@ -691,6 +691,13 @@ class HloInstruction {
const std::vector<std::pair<int64, int64>>& source_target_pairs, const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id); const absl::optional<int64>& channel_id);
// Creates a communication instruction that initiates the start of
// CollectivePermute.
//
// `operand`, `source_target_pairs` and `channel_id` have the same meaning as
// for CreateCollectivePermute. The HLO verifier expects `shape` to be the
// tuple (operand shape, operand shape, u32[], u32[]) and requires the
// resulting instruction to have exactly one user, which must be a
// collective-permute-done instruction.
static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
const Shape& shape, HloInstruction* operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id);
// Creates an instruction that returns a U32 replica ID. // Creates an instruction that returns a U32 replica ID.
static std::unique_ptr<HloInstruction> CreateReplicaId(); static std::unique_ptr<HloInstruction> CreateReplicaId();

View File

@ -703,10 +703,10 @@ bool HloAllToAllInstruction::IdenticalSlowPath(
} }
HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
const Shape& shape, HloInstruction* operand, HloOpcode opcode, const Shape& shape, HloInstruction* operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs, const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id) const absl::optional<int64>& channel_id)
: HloChannelInstruction(HloOpcode::kCollectivePermute, shape, channel_id), : HloChannelInstruction(opcode, shape, channel_id),
source_target_pairs_(source_target_pairs) { source_target_pairs_(source_target_pairs) {
AppendOperand(operand); AppendOperand(operand);
} }
@ -738,6 +738,9 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath(
const HloInstruction& other, const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>& const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const { eq_computations) const {
if (opcode() != other.opcode()) {
return false;
}
const auto& casted_other = const auto& casted_other =
static_cast<const HloCollectivePermuteInstruction&>(other); static_cast<const HloCollectivePermuteInstruction&>(other);
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
@ -752,7 +755,7 @@ HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands, const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const { HloCloneContext* /*context*/) const {
return absl::make_unique<HloCollectivePermuteInstruction>( return absl::make_unique<HloCollectivePermuteInstruction>(
shape, new_operands[0], source_target_pairs(), channel_id()); opcode(), shape, new_operands[0], source_target_pairs(), channel_id());
} }
HloReverseInstruction::HloReverseInstruction(const Shape& shape, HloReverseInstruction::HloReverseInstruction(const Shape& shape,

View File

@ -463,7 +463,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
class HloCollectivePermuteInstruction : public HloChannelInstruction { class HloCollectivePermuteInstruction : public HloChannelInstruction {
public: public:
explicit HloCollectivePermuteInstruction( explicit HloCollectivePermuteInstruction(
const Shape& shape, HloInstruction* operand, HloOpcode opcode, const Shape& shape, HloInstruction* operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs, const std::vector<std::pair<int64, int64>>& source_target_pairs,
const absl::optional<int64>& channel_id); const absl::optional<int64>& channel_id);

View File

@ -63,6 +63,8 @@ namespace xla {
V(kCholesky, "cholesky", 1) \ V(kCholesky, "cholesky", 1) \
V(kClamp, "clamp", 3) \ V(kClamp, "clamp", 3) \
V(kCollectivePermute, "collective-permute", 1) \ V(kCollectivePermute, "collective-permute", 1) \
V(kCollectivePermuteStart, "collective-permute-start", 1) \
V(kCollectivePermuteDone, "collective-permute-done", 1) \
V(kClz, "count-leading-zeros", 1) \ V(kClz, "count-leading-zeros", 1) \
V(kCompare, "compare", 2) \ V(kCompare, "compare", 2) \
V(kComplex, "complex", 2) \ V(kComplex, "complex", 2) \

View File

@ -765,6 +765,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
case HloOpcode::kBitcast: case HloOpcode::kBitcast:
case HloOpcode::kCeil: case HloOpcode::kCeil:
case HloOpcode::kClz: case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy: case HloOpcode::kCopy:
case HloOpcode::kCopyStart: case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone: case HloOpcode::kCopyDone:
@ -938,7 +939,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
split_dimension)); split_dimension));
break; break;
} }
case HloOpcode::kCollectivePermute: { case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart: {
optional<std::vector<std::vector<int64>>> source_targets; optional<std::vector<std::vector<int64>>> source_targets;
attrs["source_target_pairs"] = { attrs["source_target_pairs"] = {
/*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
@ -957,9 +959,19 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
pairs[i].first = (*source_targets)[i][0]; pairs[i].first = (*source_targets)[i][0];
pairs[i].second = (*source_targets)[i][1]; pairs[i].second = (*source_targets)[i][1];
} }
if (opcode == HloOpcode::kCollectivePermute) {
instruction = instruction =
builder->AddInstruction(HloInstruction::CreateCollectivePermute( builder->AddInstruction(HloInstruction::CreateCollectivePermute(
shape, operands[0], pairs, channel_id)); shape, operands[0], pairs, channel_id));
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
instruction = builder->AddInstruction(
HloInstruction::CreateCollectivePermuteStart(shape, operands[0],
pairs, channel_id));
} else {
LOG(FATAL) << "Expect opcode to be CollectivePermute or "
"CollectivePermuteStart, but got "
<< HloOpcodeString(opcode);
}
break; break;
} }
case HloOpcode::kReplicaId: { case HloOpcode::kReplicaId: {

View File

@ -1553,6 +1553,20 @@ ENTRY CollectivePermute {
ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
} }
)",
/*replica_count=*/4
},
// collective-permute-start and -done
{
"CollectivePermuteStartAndDone",
R"(HloModule CollectivePermuteStartAndDone
ENTRY CollectivePermuteStartAndDone {
input = f32[128,32]{0,1} parameter(0)
collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}}
ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1)
}
)", )",
/*replica_count=*/4 /*replica_count=*/4
}, },

View File

@ -74,7 +74,6 @@ Status CheckParameterCount(const HloInstruction* calling_instruction,
} }
return Status::OK(); return Status::OK();
} }
} // namespace } // namespace
Status ShapeVerifier::Preprocess(HloInstruction* hlo) { Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
@ -332,7 +331,9 @@ Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
return CheckShape(hlo, ShapeUtil::MakeShape(U32, {})); return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
} }
Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { namespace {
Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) {
// A source or target cannot appear twice in the collective-permute's // A source or target cannot appear twice in the collective-permute's
// source-target pairs. // source-target pairs.
absl::flat_hash_set<int64> seen_sources; absl::flat_hash_set<int64> seen_sources;
@ -351,10 +352,30 @@ Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
p.second, hlo->ToString()); p.second, hlo->ToString());
} }
} }
return Status::OK();
}
} // namespace
Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));
return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
hlo->operand(0)->shape())); hlo->operand(0)->shape()));
} }
// Verifies a collective-permute-start instruction.
//
// First rejects source/target pairs in which a source or a target occurs more
// than once, then checks that the instruction's shape is the tuple
// (operand shape, operand shape, u32[], u32[]).
Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) {
  TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));

  const Shape& operand_shape = hlo->operand(0)->shape();
  const Shape u32_scalar = ShapeUtil::MakeShape(U32, {});
  const Shape expected_shape = ShapeUtil::MakeTupleShape(
      {operand_shape, operand_shape, u32_scalar, u32_scalar});
  return CheckShape(hlo, expected_shape);
}
// Verifies a collective-permute-done instruction: its shape must equal the
// first element of its operand's tuple shape.
Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) {
  const Shape& start_shape = hlo->operand(0)->shape();
  const Shape& expected_shape =
      ShapeUtil::GetTupleElementShape(start_shape, 0);
  return CheckShape(hlo, expected_shape);
}
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
reduce_precision->operand(0)->shape(), reduce_precision->operand(0)->shape(),
@ -1375,32 +1396,60 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1,
return Status::OK(); return Status::OK();
} }
// Checks CopyStart and CopyDone nodes. Status VerifySingleUser(const HloInstruction* instruction,
Status VerifyAsynchronousCopies(const HloModule& module) { HloOpcode expected_user) {
TF_RET_CHECK(instruction->users().size() == 1)
<< "The " << HloOpcodeString(instruction->opcode())
<< " instruction requires one consumer, found "
<< instruction->users().size();
const HloInstruction* user = instruction->users().front();
TF_RET_CHECK(user->opcode() == expected_user)
<< "The consumer of a " << HloOpcodeString(instruction->opcode())
<< " instruction needs to be " << HloOpcodeString(expected_user)
<< ", found " << HloOpcodeString(user->opcode());
return Status::OK();
}
// Verifies that `instruction` has exactly one operand and that the operand's
// opcode is `expected_operand`.
//
// Returns Status::OK() when both conditions hold; otherwise TF_RET_CHECK
// produces an error status whose message names the opcodes involved.
Status VerifySingleOperand(const HloInstruction* instruction,
                           HloOpcode expected_operand) {
  // The condition is about operands, so the diagnostic must talk about
  // operands and print the operand count (not the number of users).
  TF_RET_CHECK(instruction->operands().size() == 1)
      << "The " << HloOpcodeString(instruction->opcode())
      << " instruction requires one operand, found "
      << instruction->operands().size();

  const HloInstruction* operand = instruction->operand(0);
  TF_RET_CHECK(operand->opcode() == expected_operand)
      << "The operand of a " << HloOpcodeString(instruction->opcode())
      << " instruction needs to be " << HloOpcodeString(expected_operand)
      << ", found " << HloOpcodeString(operand->opcode());
  return Status::OK();
}
// Checks asynchronous instruction pairs.
Status VerifyAsynchronousInstructionPairs(const HloModule& module) {
// CopyStart must have a single CopyDone user. // CopyStart must have a single CopyDone user.
for (const HloComputation* computation : module.computations()) { for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instruction : computation->instructions()) { for (const HloInstruction* instruction : computation->instructions()) {
switch (instruction->opcode()) { switch (instruction->opcode()) {
case HloOpcode::kCopyStart: { case HloOpcode::kCopyStart: {
TF_RET_CHECK(instruction->users().size() == 1) TF_RETURN_IF_ERROR(
<< "CopyStart instruction requires one consumer, found " VerifySingleUser(instruction, HloOpcode::kCopyDone));
<< instruction->users().size();
const HloInstruction* copy_done = instruction->users().front();
TF_RET_CHECK(copy_done->opcode() == HloOpcode::kCopyDone)
<< "The consumer of a CopyStart instruction needs to be "
"CopyDone, found "
<< HloOpcodeString(copy_done->opcode());
break; break;
} }
case HloOpcode::kCopyDone: { case HloOpcode::kCopyDone: {
TF_RET_CHECK(instruction->operands().size() == 1) TF_RETURN_IF_ERROR(
<< "CopyDone instruction requires one operand, found " VerifySingleOperand(instruction, HloOpcode::kCopyStart));
<< instruction->operands().size(); break;
const HloInstruction* copy_start = instruction->operand(0); }
TF_RET_CHECK(copy_start->opcode() == HloOpcode::kCopyStart) case HloOpcode::kCollectivePermuteStart: {
<< "The operand of a CopyDone instruction needs to be CopyStart, " TF_RETURN_IF_ERROR(
"found " VerifySingleUser(instruction, HloOpcode::kCollectivePermuteDone));
<< HloOpcodeString(copy_start->opcode()); break;
}
case HloOpcode::kCollectivePermuteDone: {
TF_RETURN_IF_ERROR(VerifySingleOperand(
instruction, HloOpcode::kCollectivePermuteStart));
break; break;
} }
default: default:
@ -1815,7 +1864,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
} }
TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifyAsynchronousCopies(*module)); TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module));
TF_RETURN_IF_ERROR(VerifyChannels(*module)); TF_RETURN_IF_ERROR(VerifyChannels(*module));
std::unique_ptr<ShapeVerifier> shape_verifier = std::unique_ptr<ShapeVerifier> shape_verifier =

View File

@ -60,6 +60,8 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllReduce(HloInstruction* crs) override;
Status HandleAllToAll(HloInstruction* hlo) override; Status HandleAllToAll(HloInstruction* hlo) override;
Status HandleCollectivePermute(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override;
// Shape checks for the collective-permute-start and collective-permute-done
// opcodes.
Status HandleCollectivePermuteStart(HloInstruction* hlo) override;
Status HandleCollectivePermuteDone(HloInstruction* hlo) override;
Status HandlePartitionId(HloInstruction* hlo) override; Status HandlePartitionId(HloInstruction* hlo) override;
Status HandleReplicaId(HloInstruction* hlo) override; Status HandleReplicaId(HloInstruction* hlo) override;
Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override;

View File

@ -710,7 +710,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_THAT( EXPECT_THAT(
status.error_message(), status.error_message(),
HasSubstr("CopyStart instruction requires one consumer, found 2")); HasSubstr("copy-start instruction requires one consumer, found 2"));
} }
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) { TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
@ -730,8 +730,8 @@ TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
auto status = verifier().Run(module.get()).status(); auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(), EXPECT_THAT(status.error_message(),
HasSubstr("The operand of a CopyDone instruction needs to be " HasSubstr("The operand of a copy-done instruction needs to be "
"CopyStart, found tuple")); "copy-start, found tuple"));
} }
TEST_F(HloVerifierTest, IotaNonArrayResult) { TEST_F(HloVerifierTest, IotaNonArrayResult) {
@ -1134,5 +1134,86 @@ TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
HasSubstr("used for different types of channel instructions")); HasSubstr("used for different types of channel instructions"));
} }
// A well-formed collective-permute-start / collective-permute-done pair must
// be accepted by the layout-sensitive verifier.
TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) {
  constexpr char kHloText[] = R"(
HloModule Module
ENTRY CollectivePermuteStartAndDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// A collective-permute-start whose result is a plain array instead of the
// expected 4-element tuple must be rejected with a shape-mismatch error.
TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) {
  constexpr char kHloText[] = R"(
HloModule Module
ENTRY CollectivePermuteStartAndDoneWrongType {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[], u32[])"));
}
// A collective-permute-start consumed by two collective-permute-done
// instructions must be rejected: the start may have only one consumer.
TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) {
  constexpr char kHloText[] = R"(
HloModule Module
ENTRY CollectivePermuteStartAndMultipleDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("collective-permute-start instruction requires one consumer, "
                "found 2"));
}
// A collective-permute-done whose operand is not a collective-permute-start
// (here: an ordinary tuple with the same 4-element shape) must be rejected.
TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
  // The tuple is given four operands so that it matches its declared
  // 4-element shape; the only thing wrong with this module is the missing
  // collective-permute-start.
  const char* const kModuleStr = R"(
HloModule Module
ENTRY CollectivePermuteDoneNoCollectivePermuteStart {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
p1 = f32[2,3]{1,0:S(1)} parameter(1)
p2 = u32[] parameter(2)
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p2)
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("The operand of a collective-permute-done instruction "
                        "needs to be collective-permute-start, found tuple"));
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -149,6 +149,8 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kAllReduce: case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll: case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kCustomCall: case HloOpcode::kCustomCall:
case HloOpcode::kDomain: case HloOpcode::kDomain:
case HloOpcode::kDot: case HloOpcode::kDot:

View File

@ -2234,6 +2234,8 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kBitcast: case HloOpcode::kBitcast:
case HloOpcode::kBroadcast: case HloOpcode::kBroadcast:
case HloOpcode::kCall: case HloOpcode::kCall:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kConstant: case HloOpcode::kConstant:
case HloOpcode::kConvolution: case HloOpcode::kConvolution:
case HloOpcode::kCopy: case HloOpcode::kCopy: