From 18ab11e1465f5f1ef6d323d29569f777dfea87f1 Mon Sep 17 00:00:00 2001 From: Jinliang Wei Date: Thu, 21 May 2020 21:44:14 -0700 Subject: [PATCH] [XLA] Introduce asynchronous collective-permute (CollectivePermuteStart and CollectivePermuteDone) HLO opcodes. PiperOrigin-RevId: 312794240 Change-Id: I0afa0ed1920fb97ac509ff2075559525265a28e2 --- .../compiler/xla/service/dfs_hlo_visitor.h | 2 + .../service/dfs_hlo_visitor_with_default.h | 6 ++ .../compiler/xla/service/hlo_cost_analysis.cc | 10 ++ .../compiler/xla/service/hlo_cost_analysis.h | 2 + .../compiler/xla/service/hlo_graph_dumper.cc | 2 + .../compiler/xla/service/hlo_instruction.cc | 38 +++++++- .../compiler/xla/service/hlo_instruction.h | 9 +- .../compiler/xla/service/hlo_instructions.cc | 9 +- .../compiler/xla/service/hlo_instructions.h | 2 +- tensorflow/compiler/xla/service/hlo_opcode.h | 2 + tensorflow/compiler/xla/service/hlo_parser.cc | 20 +++- .../compiler/xla/service/hlo_parser_test.cc | 14 +++ .../compiler/xla/service/hlo_verifier.cc | 91 ++++++++++++++----- .../compiler/xla/service/hlo_verifier.h | 2 + .../compiler/xla/service/hlo_verifier_test.cc | 87 +++++++++++++++++- .../xla/service/instruction_fusion.cc | 2 + .../compiler/xla/service/layout_assignment.cc | 2 + 17 files changed, 263 insertions(+), 37 deletions(-) diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index caea9d9095a..bdaac32a0e5 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -120,6 +120,8 @@ class DfsHloVisitorBase { virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0; virtual Status HandleReplicaId(HloInstructionPtr hlo) = 
0; virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0; virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 9cd220245ba..b1d674fe467 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -110,6 +110,12 @@ class DfsHloVisitorWithDefaultBase Status HandleCollectivePermute(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleReplicaId(HloInstructionPtr hlo) override { return DefaultAction(hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 32a9038b15a..50ba2077411 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -736,6 +736,16 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } +Status HloCostAnalysis::HandleCollectivePermuteStart( + const HloInstruction* /*hlo*/) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleCollectivePermuteDone( + const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 9fdb42185fb..634a6c0572c 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -80,6 +80,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleAllReduce(const 
HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; + Status HandleCollectivePermuteStart(const HloInstruction* hlo) override; + Status HandleCollectivePermuteDone(const HloInstruction* hlo) override; Status HandleReplicaId(const HloInstruction* hlo) override; Status HandlePartitionId(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index cd2a61d7eff..3930898d665 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1061,6 +1061,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kPartitionId: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9e9c8b0913b..0aadd21d0a1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -452,7 +452,8 @@ StatusOr> HloInstruction::CreateFromProto( /*channel_id=*/channel_id, split_dimension); break; } - case HloOpcode::kCollectivePermute: { + case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: { std::vector> source_target_pairs( proto.source_target_pairs_size()); absl::optional channel_id; @@ -463,8 +464,17 @@ StatusOr> HloInstruction::CreateFromProto( source_target_pairs[i].first = proto.source_target_pairs(i).source(); source_target_pairs[i].second = proto.source_target_pairs(i).target(); } - instruction = CreateCollectivePermute(shape, operands(0), - 
source_target_pairs, channel_id); + + if (opcode == HloOpcode::kCollectivePermute) { + instruction = CreateCollectivePermute(shape, operands(0), + source_target_pairs, channel_id); + } else if (opcode == HloOpcode::kCollectivePermuteStart) { + instruction = CreateCollectivePermuteStart( + shape, operands(0), source_target_pairs, channel_id); + } else { + LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, " + << "but got " << HloOpcodeString(opcode); + } break; } case HloOpcode::kReplicaId: { @@ -805,6 +815,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kRoundNearestAfz: case HloOpcode::kBitcast: case HloOpcode::kCeil: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: @@ -982,7 +993,18 @@ HloInstruction::CreateCollectivePermute( const std::vector>& source_target_pairs, const absl::optional& channel_id) { return absl::make_unique( - shape, operand, source_target_pairs, channel_id); + HloOpcode::kCollectivePermute, shape, operand, source_target_pairs, + channel_id); +} + +/* static */ std::unique_ptr +HloInstruction::CreateCollectivePermuteStart( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs, + const absl::optional& channel_id) { + return absl::make_unique( + HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs, + channel_id); } /* static */ std::unique_ptr HloInstruction::CreateReplicaId() { @@ -1549,6 +1571,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1575,6 +1598,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kClz: + case HloOpcode::kCollectivePermuteDone: case 
HloOpcode::kCopy: case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: @@ -1928,6 +1952,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCopy: @@ -2029,6 +2054,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -2888,6 +2914,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleAllToAll(this); case HloOpcode::kCollectivePermute: return visitor->HandleCollectivePermute(this); + case HloOpcode::kCollectivePermuteStart: + return visitor->HandleCollectivePermuteStart(this); + case HloOpcode::kCollectivePermuteDone: + return visitor->HandleCollectivePermuteDone(this); case HloOpcode::kReplicaId: return visitor->HandleReplicaId(this); case HloOpcode::kPartitionId: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8be7a034877..c6cfda8e505 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -681,7 +681,7 @@ class HloInstruction { const absl::optional& channel_id, const absl::optional& split_dimension = absl::nullopt); - // Creates a communication instructions that permutes data cross replicas. + // Creates a communication instruction that permutes data cross replicas. // Data is sent/received according to the (source_replica_id, // target_replica_id) pairs in `source_target_pairs`. 
If a replica id is not a // target_replica_id in any pair, the output on that replica is a tensor @@ -691,6 +691,13 @@ class HloInstruction { const std::vector>& source_target_pairs, const absl::optional& channel_id); + // Creates a communication instruction that initiates the start of + // CollectivePermute. + static std::unique_ptr CreateCollectivePermuteStart( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs, + const absl::optional& channel_id); + // Creates an instruction that returns a U32 replica ID. static std::unique_ptr CreateReplicaId(); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index d5bdd674563..e33d5960894 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -703,10 +703,10 @@ bool HloAllToAllInstruction::IdenticalSlowPath( } HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( - const Shape& shape, HloInstruction* operand, + HloOpcode opcode, const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs, const absl::optional& channel_id) - : HloChannelInstruction(HloOpcode::kCollectivePermute, shape, channel_id), + : HloChannelInstruction(opcode, shape, channel_id), source_target_pairs_(source_target_pairs) { AppendOperand(operand); } @@ -738,6 +738,9 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& eq_computations) const { + if (opcode() != other.opcode()) { + return false; + } const auto& casted_other = static_cast(other); return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && @@ -752,7 +755,7 @@ HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( - shape, new_operands[0], source_target_pairs(), channel_id()); + opcode(), shape, 
new_operands[0], source_target_pairs(), channel_id()); } HloReverseInstruction::HloReverseInstruction(const Shape& shape, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index ae78d365cfa..7f06c801e38 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -463,7 +463,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { class HloCollectivePermuteInstruction : public HloChannelInstruction { public: explicit HloCollectivePermuteInstruction( - const Shape& shape, HloInstruction* operand, + HloOpcode opcode, const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs, const absl::optional& channel_id); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 664fa10a990..92359bcbdac 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -63,6 +63,8 @@ namespace xla { V(kCholesky, "cholesky", 1) \ V(kClamp, "clamp", 3) \ V(kCollectivePermute, "collective-permute", 1) \ + V(kCollectivePermuteStart, "collective-permute-start", 1) \ + V(kCollectivePermuteDone, "collective-permute-done", 1) \ V(kClz, "count-leading-zeros", 1) \ V(kCompare, "compare", 2) \ V(kComplex, "complex", 2) \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index f1908bcb996..d52a60d2555 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -765,6 +765,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kClz: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: @@ -938,7 +939,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* 
builder, split_dimension)); break; } - case HloOpcode::kCollectivePermute: { + case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: { optional>> source_targets; attrs["source_target_pairs"] = { /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; @@ -957,9 +959,19 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, pairs[i].first = (*source_targets)[i][0]; pairs[i].second = (*source_targets)[i][1]; } - instruction = - builder->AddInstruction(HloInstruction::CreateCollectivePermute( - shape, operands[0], pairs, channel_id)); + if (opcode == HloOpcode::kCollectivePermute) { + instruction = + builder->AddInstruction(HloInstruction::CreateCollectivePermute( + shape, operands[0], pairs, channel_id)); + } else if (opcode == HloOpcode::kCollectivePermuteStart) { + instruction = builder->AddInstruction( + HloInstruction::CreateCollectivePermuteStart(shape, operands[0], + pairs, channel_id)); + } else { + LOG(FATAL) << "Expect opcode to be CollectivePermute or " + "CollectivePermuteStart, but got " + << HloOpcodeString(opcode); + } break; } case HloOpcode::kReplicaId: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 8f63835b43d..a687d0e1921 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1553,6 +1553,20 @@ ENTRY CollectivePermute { ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } +)", +/*replica_count=*/4 +}, +// collective-permute-start and -done +{ +"CollectivePermuteStartAndDone", +R"(HloModule CollectivePermuteStartAndDone + +ENTRY CollectivePermuteStartAndDone { + input = f32[128,32]{0,1} parameter(0) + collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}} + ROOT collective-permute-done.1 = f32[128,32]{0,1} 
collective-permute-done(collective-permute-start.1) +} + )", /*replica_count=*/4 }, diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index d15a36532eb..4661b8fd9e3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -74,7 +74,6 @@ Status CheckParameterCount(const HloInstruction* calling_instruction, } return Status::OK(); } - } // namespace Status ShapeVerifier::Preprocess(HloInstruction* hlo) { @@ -332,7 +331,9 @@ Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) { return CheckShape(hlo, ShapeUtil::MakeShape(U32, {})); } -Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { +namespace { + +Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) { // A source or target cannot appear twice in the collective-permute's // source-target pairs. absl::flat_hash_set seen_sources; @@ -351,10 +352,30 @@ Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { p.second, hlo->ToString()); } } + return Status::OK(); +} + +} // namespace + +Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } +Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo)); + return CheckShape( + hlo, ShapeUtil::MakeTupleShape( + {hlo->operand(0)->shape(), hlo->operand(0)->shape(), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})})); +} + +Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) { + return CheckShape( + hlo, ShapeUtil::GetTupleElementShape(hlo->operand(0)->shape(), 0)); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( 
reduce_precision->operand(0)->shape(), @@ -1375,32 +1396,60 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1, return Status::OK(); } -// Checks CopyStart and CopyDone nodes. -Status VerifyAsynchronousCopies(const HloModule& module) { +Status VerifySingleUser(const HloInstruction* instruction, + HloOpcode expected_user) { + TF_RET_CHECK(instruction->users().size() == 1) + << "The " << HloOpcodeString(instruction->opcode()) + << " instruction requires one consumer, found " + << instruction->users().size(); + + const HloInstruction* user = instruction->users().front(); + TF_RET_CHECK(user->opcode() == expected_user) + << "The consumer of a " << HloOpcodeString(instruction->opcode()) + << " instruction needs to be " << HloOpcodeString(expected_user) + << ", found " << HloOpcodeString(user->opcode()); + return Status::OK(); +} + +Status VerifySingleOperand(const HloInstruction* instruction, + HloOpcode expected_operand) { + TF_RET_CHECK(instruction->operands().size() == 1) + << "The " << HloOpcodeString(instruction->opcode()) + << " instruction requires one operand, found " + << instruction->operands().size(); + + const HloInstruction* operand = instruction->operand(0); + TF_RET_CHECK(operand->opcode() == expected_operand) + << "The operand of a " << HloOpcodeString(instruction->opcode()) + << " instruction needs to be " << HloOpcodeString(expected_operand) + << ", found " << HloOpcodeString(operand->opcode()); + return Status::OK(); +} + +// Checks asynchronous instruction pairs. +Status VerifyAsynchronousInstructionPairs(const HloModule& module) { // CopyStart must have a single CopyDone user.
for (const HloComputation* computation : module.computations()) { for (const HloInstruction* instruction : computation->instructions()) { switch (instruction->opcode()) { case HloOpcode::kCopyStart: { - TF_RET_CHECK(instruction->users().size() == 1) - << "CopyStart instruction requires one consumer, found " - << instruction->users().size(); - const HloInstruction* copy_done = instruction->users().front(); - TF_RET_CHECK(copy_done->opcode() == HloOpcode::kCopyDone) - << "The consumer of a CopyStart instruction needs to be " - "CopyDone, found " - << HloOpcodeString(copy_done->opcode()); + TF_RETURN_IF_ERROR( + VerifySingleUser(instruction, HloOpcode::kCopyDone)); break; } case HloOpcode::kCopyDone: { - TF_RET_CHECK(instruction->operands().size() == 1) - << "CopyDone instruction requires one operand, found " - << instruction->operands().size(); - const HloInstruction* copy_start = instruction->operand(0); - TF_RET_CHECK(copy_start->opcode() == HloOpcode::kCopyStart) - << "The operand of a CopyDone instruction needs to be CopyStart, " - "found " - << HloOpcodeString(copy_start->opcode()); + TF_RETURN_IF_ERROR( + VerifySingleOperand(instruction, HloOpcode::kCopyStart)); + break; + } + case HloOpcode::kCollectivePermuteStart: { + TF_RETURN_IF_ERROR( + VerifySingleUser(instruction, HloOpcode::kCollectivePermuteDone)); + break; + } + case HloOpcode::kCollectivePermuteDone: { + TF_RETURN_IF_ERROR(VerifySingleOperand( + instruction, HloOpcode::kCollectivePermuteStart)); break; } default: @@ -1815,7 +1864,7 @@ StatusOr HloVerifier::Run(HloModule* module) { } TF_RETURN_IF_ERROR(VerifyHloStructure(module)); - TF_RETURN_IF_ERROR(VerifyAsynchronousCopies(*module)); + TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module)); TF_RETURN_IF_ERROR(VerifyChannels(*module)); std::unique_ptr shape_verifier = diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 7a2d3dc2e6c..85b02e0518c 100644 --- 
a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -60,6 +60,8 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; + Status HandleCollectivePermuteStart(HloInstruction* hlo) override; + Status HandleCollectivePermuteDone(HloInstruction* hlo) override; Status HandlePartitionId(HloInstruction* hlo) override; Status HandleReplicaId(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index e2c363e40c5..294dfbf66fa 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -710,7 +710,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) { ASSERT_FALSE(status.ok()); EXPECT_THAT( status.error_message(), - HasSubstr("CopyStart instruction requires one consumer, found 2")); + HasSubstr("copy-start instruction requires one consumer, found 2")); } TEST_F(HloVerifierTest, CopyDoneNoCopyStart) { @@ -730,8 +730,8 @@ TEST_F(HloVerifierTest, CopyDoneNoCopyStart) { auto status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), - HasSubstr("The operand of a CopyDone instruction needs to be " - "CopyStart, found tuple")); + HasSubstr("The operand of a copy-done instruction needs to be " + "copy-start, found tuple")); } TEST_F(HloVerifierTest, IotaNonArrayResult) { @@ -1134,5 +1134,86 @@ TEST_F(HloVerifierTest, CollectiveChannelVerifier) { HasSubstr("used for different types of channel instructions")); } +TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteStartAndDone { + p0 
= f32[2,3]{1,0:S(1)} parameter(0) + collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1 + ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteStartAndDoneWrongType { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1 + ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected instruction to have shape equal to " + "(f32[2,3], f32[2,3], u32[], u32[])")); +} + +TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteStartAndMultipleDone { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1 + collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + 
ASSERT_FALSE(status.ok()); + EXPECT_THAT( + status.error_message(), + HasSubstr("collective-permute-start instruction requires one consumer, " + "found 2")); +} + +TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteDoneNoCollectivePermuteStart { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + p1 = f32[2,3]{1,0:S(1)} parameter(1) + p2 = u32[] parameter(2) + tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2) + ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("The operand of a collective-permute-done instruction " + "needs to be collective-permute-start, found tuple")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 5de081c6343..02966cc2bf2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -149,6 +149,8 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kDot: diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 13699f3adf9..82c30f1a710 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2234,6 +2234,8 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kBitcast: case 
HloOpcode::kBroadcast: case HloOpcode::kCall: + case HloOpcode::kCollectivePermuteStart: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kConstant: case HloOpcode::kConvolution: case HloOpcode::kCopy: