[XLA] Introduce asynchronous collective-permute (CollectivePermuteStart and CollectivePermuteDone) HLO opcodes.
PiperOrigin-RevId: 312794240 Change-Id: I0afa0ed1920fb97ac509ff2075559525265a28e2
This commit is contained in:
parent
987a095f85
commit
18ab11e146
@ -120,6 +120,8 @@ class DfsHloVisitorBase {
|
||||
virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0;
|
||||
|
@ -110,6 +110,12 @@ class DfsHloVisitorWithDefaultBase
|
||||
// Collective-permute gets no dedicated treatment in this visitor; the
// instruction is forwarded to DefaultAction and its status is returned.
Status HandleCollectivePermute(HloInstructionPtr hlo) override {
  Status default_status = DefaultAction(hlo);
  return default_status;
}
|
||||
// Collective-permute-start is forwarded to DefaultAction; the resulting status
// is returned unchanged.
Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override {
  Status default_status = DefaultAction(hlo);
  return default_status;
}
|
||||
// Collective-permute-done is forwarded to DefaultAction; the resulting status
// is returned unchanged.
Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override {
  Status default_status = DefaultAction(hlo);
  return default_status;
}
|
||||
// Replica-id is forwarded to DefaultAction; the resulting status is returned
// unchanged.
Status HandleReplicaId(HloInstructionPtr hlo) override {
  Status default_status = DefaultAction(hlo);
  return default_status;
}
|
||||
|
@ -736,6 +736,16 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The instruction argument is ignored: this handler records nothing and
// always reports success.
Status HloCostAnalysis::HandleCollectivePermuteStart(
    const HloInstruction* /*hlo*/) {
  Status ok_status = Status::OK();
  return ok_status;
}
|
||||
|
||||
// The instruction argument is ignored: this handler records nothing and
// always reports success.
Status HloCostAnalysis::HandleCollectivePermuteDone(
    const HloInstruction* /*hlo*/) {
  Status ok_status = Status::OK();
  return ok_status;
}
|
||||
|
||||
// The instruction argument is ignored: this handler records nothing and
// always reports success.
Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) {
  Status ok_status = Status::OK();
  return ok_status;
}
|
||||
|
@ -80,6 +80,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
|
||||
Status HandleAllReduce(const HloInstruction* crs) override;
|
||||
Status HandleAllToAll(const HloInstruction* hlo) override;
|
||||
Status HandleCollectivePermute(const HloInstruction* hlo) override;
|
||||
Status HandleCollectivePermuteStart(const HloInstruction* hlo) override;
|
||||
Status HandleCollectivePermuteDone(const HloInstruction* hlo) override;
|
||||
Status HandleReplicaId(const HloInstruction* hlo) override;
|
||||
Status HandlePartitionId(const HloInstruction* hlo) override;
|
||||
Status HandleInfeed(const HloInstruction* infeed) override;
|
||||
|
@ -1061,6 +1061,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kAllToAll:
|
||||
case HloOpcode::kCollectivePermute:
|
||||
case HloOpcode::kCollectivePermuteStart:
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kPartitionId:
|
||||
|
@ -452,7 +452,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
/*channel_id=*/channel_id, split_dimension);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCollectivePermute: {
|
||||
case HloOpcode::kCollectivePermute:
|
||||
case HloOpcode::kCollectivePermuteStart: {
|
||||
std::vector<std::pair<int64, int64>> source_target_pairs(
|
||||
proto.source_target_pairs_size());
|
||||
absl::optional<int64> channel_id;
|
||||
@ -463,8 +464,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
source_target_pairs[i].first = proto.source_target_pairs(i).source();
|
||||
source_target_pairs[i].second = proto.source_target_pairs(i).target();
|
||||
}
|
||||
instruction = CreateCollectivePermute(shape, operands(0),
|
||||
source_target_pairs, channel_id);
|
||||
|
||||
if (opcode == HloOpcode::kCollectivePermute) {
|
||||
instruction = CreateCollectivePermute(shape, operands(0),
|
||||
source_target_pairs, channel_id);
|
||||
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
|
||||
instruction = CreateCollectivePermuteStart(
|
||||
shape, operands(0), source_target_pairs, channel_id);
|
||||
} else {
|
||||
LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
|
||||
<< "but got " << HloOpcodeString(opcode);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kReplicaId: {
|
||||
@ -805,6 +815,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
|
||||
case HloOpcode::kRoundNearestAfz:
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCopyStart:
|
||||
case HloOpcode::kCopyDone:
|
||||
@ -982,7 +993,18 @@ HloInstruction::CreateCollectivePermute(
|
||||
const std::vector<std::pair<int64, int64>>& source_target_pairs,
|
||||
const absl::optional<int64>& channel_id) {
|
||||
return absl::make_unique<HloCollectivePermuteInstruction>(
|
||||
shape, operand, source_target_pairs, channel_id);
|
||||
HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
|
||||
channel_id);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction>
|
||||
HloInstruction::CreateCollectivePermuteStart(
|
||||
const Shape& shape, HloInstruction* operand,
|
||||
const std::vector<std::pair<int64, int64>>& source_target_pairs,
|
||||
const absl::optional<int64>& channel_id) {
|
||||
return absl::make_unique<HloCollectivePermuteInstruction>(
|
||||
HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs,
|
||||
channel_id);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
|
||||
@ -1549,6 +1571,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kAllToAll:
|
||||
case HloOpcode::kCollectivePermute:
|
||||
case HloOpcode::kCollectivePermuteStart:
|
||||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kConvolution:
|
||||
@ -1575,6 +1598,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kClz:
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCopyStart:
|
||||
case HloOpcode::kCopyDone:
|
||||
@ -1928,6 +1952,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kClamp:
|
||||
case HloOpcode::kClz:
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kCopy:
|
||||
@ -2029,6 +2054,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kAllToAll:
|
||||
case HloOpcode::kCollectivePermute:
|
||||
case HloOpcode::kCollectivePermuteStart:
|
||||
case HloOpcode::kConvolution:
|
||||
case HloOpcode::kCustomCall:
|
||||
case HloOpcode::kReduceWindow:
|
||||
@ -2888,6 +2914,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
||||
return visitor->HandleAllToAll(this);
|
||||
case HloOpcode::kCollectivePermute:
|
||||
return visitor->HandleCollectivePermute(this);
|
||||
case HloOpcode::kCollectivePermuteStart:
|
||||
return visitor->HandleCollectivePermuteStart(this);
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
return visitor->HandleCollectivePermuteDone(this);
|
||||
case HloOpcode::kReplicaId:
|
||||
return visitor->HandleReplicaId(this);
|
||||
case HloOpcode::kPartitionId:
|
||||
|
@ -681,7 +681,7 @@ class HloInstruction {
|
||||
const absl::optional<int64>& channel_id,
|
||||
const absl::optional<int64>& split_dimension = absl::nullopt);
|
||||
|
||||
// Creates a communication instructions that permutes data cross replicas.
|
||||
// Creates a communication instruction that permutes data cross replicas.
|
||||
// Data is sent/received according to the (source_replica_id,
|
||||
// target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
|
||||
// target_replica_id in any pair, the output on that replica is a tensor
|
||||
@ -691,6 +691,13 @@ class HloInstruction {
|
||||
const std::vector<std::pair<int64, int64>>& source_target_pairs,
|
||||
const absl::optional<int64>& channel_id);
|
||||
|
||||
// Creates a communication instruction that initiates the start of
|
||||
// CollectivePermute.
|
||||
static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
|
||||
const Shape& shape, HloInstruction* operand,
|
||||
const std::vector<std::pair<int64, int64>>& source_target_pairs,
|
||||
const absl::optional<int64>& channel_id);
|
||||
|
||||
// Creates an instruction that returns a U32 replica ID.
|
||||
static std::unique_ptr<HloInstruction> CreateReplicaId();
|
||||
|
||||
|
@ -703,10 +703,10 @@ bool HloAllToAllInstruction::IdenticalSlowPath(
|
||||
}
|
||||
|
||||
HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
|
||||
const Shape& shape, HloInstruction* operand,
|
||||
HloOpcode opcode, const Shape& shape, HloInstruction* operand,
|
||||
const std::vector<std::pair<int64, int64>>& source_target_pairs,
|
||||
const absl::optional<int64>& channel_id)
|
||||
: HloChannelInstruction(HloOpcode::kCollectivePermute, shape, channel_id),
|
||||
: HloChannelInstruction(opcode, shape, channel_id),
|
||||
source_target_pairs_(source_target_pairs) {
|
||||
AppendOperand(operand);
|
||||
}
|
||||
@ -738,6 +738,9 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const {
|
||||
if (opcode() != other.opcode()) {
|
||||
return false;
|
||||
}
|
||||
const auto& casted_other =
|
||||
static_cast<const HloCollectivePermuteInstruction&>(other);
|
||||
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
|
||||
@ -752,7 +755,7 @@ HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* /*context*/) const {
|
||||
return absl::make_unique<HloCollectivePermuteInstruction>(
|
||||
shape, new_operands[0], source_target_pairs(), channel_id());
|
||||
opcode(), shape, new_operands[0], source_target_pairs(), channel_id());
|
||||
}
|
||||
|
||||
HloReverseInstruction::HloReverseInstruction(const Shape& shape,
|
||||
|
@ -463,7 +463,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
|
||||
class HloCollectivePermuteInstruction : public HloChannelInstruction {
|
||||
public:
|
||||
explicit HloCollectivePermuteInstruction(
|
||||
const Shape& shape, HloInstruction* operand,
|
||||
HloOpcode opcode, const Shape& shape, HloInstruction* operand,
|
||||
const std::vector<std::pair<int64, int64>>& source_target_pairs,
|
||||
const absl::optional<int64>& channel_id);
|
||||
|
||||
|
@ -63,6 +63,8 @@ namespace xla {
|
||||
V(kCholesky, "cholesky", 1) \
|
||||
V(kClamp, "clamp", 3) \
|
||||
V(kCollectivePermute, "collective-permute", 1) \
|
||||
V(kCollectivePermuteStart, "collective-permute-start", 1) \
|
||||
V(kCollectivePermuteDone, "collective-permute-done", 1) \
|
||||
V(kClz, "count-leading-zeros", 1) \
|
||||
V(kCompare, "compare", 2) \
|
||||
V(kComplex, "complex", 2) \
|
||||
|
@ -765,6 +765,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kClz:
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCopyStart:
|
||||
case HloOpcode::kCopyDone:
|
||||
@ -938,7 +939,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
split_dimension));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCollectivePermute: {
|
||||
case HloOpcode::kCollectivePermute:
|
||||
case HloOpcode::kCollectivePermuteStart: {
|
||||
optional<std::vector<std::vector<int64>>> source_targets;
|
||||
attrs["source_target_pairs"] = {
|
||||
/*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
|
||||
@ -957,9 +959,19 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
pairs[i].first = (*source_targets)[i][0];
|
||||
pairs[i].second = (*source_targets)[i][1];
|
||||
}
|
||||
instruction =
|
||||
builder->AddInstruction(HloInstruction::CreateCollectivePermute(
|
||||
shape, operands[0], pairs, channel_id));
|
||||
if (opcode == HloOpcode::kCollectivePermute) {
|
||||
instruction =
|
||||
builder->AddInstruction(HloInstruction::CreateCollectivePermute(
|
||||
shape, operands[0], pairs, channel_id));
|
||||
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
|
||||
instruction = builder->AddInstruction(
|
||||
HloInstruction::CreateCollectivePermuteStart(shape, operands[0],
|
||||
pairs, channel_id));
|
||||
} else {
|
||||
LOG(FATAL) << "Expect opcode to be CollectivePermute or "
|
||||
"CollectivePermuteStart, but got "
|
||||
<< HloOpcodeString(opcode);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kReplicaId: {
|
||||
|
@ -1553,6 +1553,20 @@ ENTRY CollectivePermute {
|
||||
ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
|
||||
}
|
||||
|
||||
)",
|
||||
/*replica_count=*/4
|
||||
},
|
||||
// collective-permute-start and -done
|
||||
{
|
||||
"CollectivePermuteStartAndDone",
|
||||
R"(HloModule CollectivePermuteStartAndDone
|
||||
|
||||
ENTRY CollectivePermuteStartAndDone {
|
||||
input = f32[128,32]{0,1} parameter(0)
|
||||
collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}}
|
||||
ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1)
|
||||
}
|
||||
|
||||
)",
|
||||
/*replica_count=*/4
|
||||
},
|
||||
|
@ -74,7 +74,6 @@ Status CheckParameterCount(const HloInstruction* calling_instruction,
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
|
||||
@ -332,7 +331,9 @@ Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
|
||||
return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
|
||||
namespace {
|
||||
|
||||
Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) {
|
||||
// A source or target cannot appear twice in the collective-permute's
|
||||
// source-target pairs.
|
||||
absl::flat_hash_set<int64> seen_sources;
|
||||
@ -351,10 +352,30 @@ Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
|
||||
p.second, hlo->ToString());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Verifies a collective-permute instruction in two steps:
//   1. its source/target pairs pass CheckDuplicatedSourceOrTarget;
//   2. its shape matches what ShapeInference derives from operand 0.
Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
  const Status pairs_status = CheckDuplicatedSourceOrTarget(hlo);
  if (!pairs_status.ok()) {
    return pairs_status;
  }

  const Shape& operand_shape = hlo->operand(0)->shape();
  return CheckShape(
      hlo, ShapeInference::InferCollectivePermuteShape(operand_shape));
}
|
||||
|
||||
// Verifies a collective-permute-start instruction: the source/target pairs
// must pass CheckDuplicatedSourceOrTarget, and the instruction's shape must be
// the 4-tuple (operand shape, operand shape, u32[], u32[]).
Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) {
  TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));

  const Shape& operand_shape = hlo->operand(0)->shape();
  const Shape expected_shape = ShapeUtil::MakeTupleShape(
      {operand_shape, operand_shape, ShapeUtil::MakeShape(U32, {}),
       ShapeUtil::MakeShape(U32, {})});
  return CheckShape(hlo, expected_shape);
}
|
||||
|
||||
// Verifies a collective-permute-done instruction: its shape must equal
// element 0 of the tuple shape produced by its (single) operand.
Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) {
  const Shape& operand_shape = hlo->operand(0)->shape();
  const Shape& expected_shape =
      ShapeUtil::GetTupleElementShape(operand_shape, 0);
  return CheckShape(hlo, expected_shape);
}
|
||||
|
||||
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
|
||||
return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
|
||||
reduce_precision->operand(0)->shape(),
|
||||
@ -1375,32 +1396,60 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Checks CopyStart and CopyDone nodes.
|
||||
Status VerifyAsynchronousCopies(const HloModule& module) {
|
||||
// Checks that `instruction` is consumed by exactly one user and that this
// user's opcode is `expected_user`.
//
// Returns Status::OK() when both conditions hold. Otherwise the failing
// TF_RET_CHECK produces an error whose message names the opcodes involved
// (and, for the count check, the number of users found).
Status VerifySingleUser(const HloInstruction* instruction,
                        HloOpcode expected_user) {
  // Exactly one consumer is allowed.
  TF_RET_CHECK(instruction->users().size() == 1)
      << "The " << HloOpcodeString(instruction->opcode())
      << " instruction requires one consumer, found "
      << instruction->users().size();

  // That sole consumer must carry the expected opcode.
  const HloInstruction* user = *instruction->users().begin();
  TF_RET_CHECK(user->opcode() == expected_user)
      << "The consumer of a " << HloOpcodeString(instruction->opcode())
      << " instruction needs to be " << HloOpcodeString(expected_user)
      << ", found " << HloOpcodeString(user->opcode());

  return Status::OK();
}
|
||||
|
||||
// Checks that `instruction` has exactly one operand and that this operand's
// opcode is `expected_operand`.
//
// Returns Status::OK() when both conditions hold. Otherwise the failing
// TF_RET_CHECK produces an error whose message names the opcodes involved
// (and, for the count check, the number of operands found).
Status VerifySingleOperand(const HloInstruction* instruction,
                           HloOpcode expected_operand) {
  // Report the operand count here: this check is about operands, not users,
  // so the diagnostic must describe what was actually tested.
  TF_RET_CHECK(instruction->operands().size() == 1)
      << "The " << HloOpcodeString(instruction->opcode())
      << " instruction requires one operand, found "
      << instruction->operands().size();

  const HloInstruction* operand = instruction->operand(0);
  TF_RET_CHECK(operand->opcode() == expected_operand)
      << "The operand of a " << HloOpcodeString(instruction->opcode())
      << " instruction needs to be " << HloOpcodeString(expected_operand)
      << ", found " << HloOpcodeString(operand->opcode());
  return Status::OK();
}
|
||||
|
||||
// Checks asynchronous instruction pairs.
|
||||
Status VerifyAsynchronousInstructionPairs(const HloModule& module) {
|
||||
// CopyStart must have a single CopyDone user.
|
||||
for (const HloComputation* computation : module.computations()) {
|
||||
for (const HloInstruction* instruction : computation->instructions()) {
|
||||
switch (instruction->opcode()) {
|
||||
case HloOpcode::kCopyStart: {
|
||||
TF_RET_CHECK(instruction->users().size() == 1)
|
||||
<< "CopyStart instruction requires one consumer, found "
|
||||
<< instruction->users().size();
|
||||
const HloInstruction* copy_done = instruction->users().front();
|
||||
TF_RET_CHECK(copy_done->opcode() == HloOpcode::kCopyDone)
|
||||
<< "The consumer of a CopyStart instruction needs to be "
|
||||
"CopyDone, found "
|
||||
<< HloOpcodeString(copy_done->opcode());
|
||||
TF_RETURN_IF_ERROR(
|
||||
VerifySingleUser(instruction, HloOpcode::kCopyDone));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCopyDone: {
|
||||
TF_RET_CHECK(instruction->operands().size() == 1)
|
||||
<< "CopyDone instruction requires one operand, found "
|
||||
<< instruction->operands().size();
|
||||
const HloInstruction* copy_start = instruction->operand(0);
|
||||
TF_RET_CHECK(copy_start->opcode() == HloOpcode::kCopyStart)
|
||||
<< "The operand of a CopyDone instruction needs to be CopyStart, "
|
||||
"found "
|
||||
<< HloOpcodeString(copy_start->opcode());
|
||||
TF_RETURN_IF_ERROR(
|
||||
VerifySingleOperand(instruction, HloOpcode::kCopyStart));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCollectivePermuteStart: {
|
||||
TF_RETURN_IF_ERROR(
|
||||
VerifySingleUser(instruction, HloOpcode::kCollectivePermuteDone));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCollectivePermuteDone: {
|
||||
TF_RETURN_IF_ERROR(VerifySingleOperand(
|
||||
instruction, HloOpcode::kCollectivePermuteStart));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@ -1815,7 +1864,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
|
||||
TF_RETURN_IF_ERROR(VerifyAsynchronousCopies(*module));
|
||||
TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module));
|
||||
TF_RETURN_IF_ERROR(VerifyChannels(*module));
|
||||
|
||||
std::unique_ptr<ShapeVerifier> shape_verifier =
|
||||
|
@ -60,6 +60,8 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
Status HandleAllReduce(HloInstruction* crs) override;
|
||||
Status HandleAllToAll(HloInstruction* hlo) override;
|
||||
Status HandleCollectivePermute(HloInstruction* hlo) override;
|
||||
Status HandleCollectivePermuteStart(HloInstruction* hlo) override;
|
||||
Status HandleCollectivePermuteDone(HloInstruction* hlo) override;
|
||||
Status HandlePartitionId(HloInstruction* hlo) override;
|
||||
Status HandleReplicaId(HloInstruction* hlo) override;
|
||||
Status HandleReducePrecision(HloInstruction* reduce_precision) override;
|
||||
|
@ -710,7 +710,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.error_message(),
|
||||
HasSubstr("CopyStart instruction requires one consumer, found 2"));
|
||||
HasSubstr("copy-start instruction requires one consumer, found 2"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
|
||||
@ -730,8 +730,8 @@ TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("The operand of a CopyDone instruction needs to be "
|
||||
"CopyStart, found tuple"));
|
||||
HasSubstr("The operand of a copy-done instruction needs to be "
|
||||
"copy-start, found tuple"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, IotaNonArrayResult) {
|
||||
@ -1134,5 +1134,86 @@ TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
|
||||
HasSubstr("used for different types of channel instructions"));
|
||||
}
|
||||
|
||||
// A well-formed collective-permute-start / collective-permute-done pair must
// be accepted by the layout-sensitive verifier.
TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CollectivePermuteStartAndDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto status = verifier().Run(parsed_module.get()).status();
  ASSERT_TRUE(status.ok());
}
|
||||
|
||||
// A collective-permute-start whose result is a plain array instead of the
// expected 4-tuple must be rejected with a shape-mismatch message.
TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CollectivePermuteStartAndDoneWrongType {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[], u32[])"));
}
|
||||
|
||||
// A collective-permute-start consumed by two collective-permute-done
// instructions must be rejected: the start requires exactly one consumer.
TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CollectivePermuteStartAndMultipleDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr("collective-permute-start instruction requires one consumer, "
                "found 2"));
}
|
||||
|
||||
// A collective-permute-done fed by something other than a
// collective-permute-start (here: a plain tuple) must be rejected.
//
// The tuple is built from four parameters so that it matches its declared
// 4-element shape; the only thing wrong with this module is the opcode of the
// collective-permute-done operand, which is the property under test.
TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
  const char* const kModuleStr = R"(
HloModule Module

ENTRY CollectivePermuteDoneNoCollectivePermuteStart {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
p1 = f32[2,3]{1,0:S(1)} parameter(1)
p2 = u32[] parameter(2)
p3 = u32[] parameter(3)
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("The operand of a collective-permute-done instruction "
                        "needs to be collective-permute-start, found tuple"));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -149,6 +149,8 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kAllToAll:
|
||||
case HloOpcode::kCollectivePermute:
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
case HloOpcode::kCollectivePermuteStart:
|
||||
case HloOpcode::kCustomCall:
|
||||
case HloOpcode::kDomain:
|
||||
case HloOpcode::kDot:
|
||||
|
@ -2234,6 +2234,8 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kBroadcast:
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kCollectivePermuteStart:
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kConvolution:
|
||||
case HloOpcode::kCopy:
|
||||
|
Loading…
x
Reference in New Issue
Block a user