diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index bae78c94bdc..70b18ff3567 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -321,6 +321,7 @@ void ComputeComputationPostOrder( enum State { kVisiting, kVisited }; void ComputeInstructionPostOrder( + std::map> channel_dependency_map, std::vector* post_order, HloInstruction* root, tensorflow::gtl::FlatMap* visited) { std::vector dfs_stack; @@ -355,12 +356,67 @@ void ComputeInstructionPostOrder( for (HloInstruction* op : current->control_predecessors()) { dfs_stack.emplace_back(op); } + + // Add inputs for send->recv_done dependencies and cross-replica-sum + // dependencies. + switch (current->opcode()) { + case HloOpcode::kRecvDone: { + const auto& dependencies = + channel_dependency_map[current->channel_id()]; + for (HloInstruction* op : dependencies) { + dfs_stack.emplace_back(op); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = current->all_reduce_id(); + if (all_reduce_id) { + const auto& dependencies = + channel_dependency_map[all_reduce_id.value()]; + for (HloInstruction* op : dependencies) { + dfs_stack.emplace_back(op); + } + } + break; + } + default: + break; + } } } } // namespace +std::map> +HloComputation::ComputeChannelDependencies() const { + std::map> channel_dependency_map; + for (const auto& instruction : instructions_) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + channel_dependency_map[instruction->channel_id()].push_back( + instruction.get()); + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = instruction->all_reduce_id(); + if (all_reduce_id) { + auto& dependencies = channel_dependency_map[all_reduce_id.value()]; + absl::c_copy(instruction->operands(), + std::back_inserter(dependencies)); + absl::c_copy(instruction->control_predecessors(), + std::back_inserter(dependencies)); + } + break; + } + default: + break; + } + } + return channel_dependency_map; +} + std::vector HloComputation::MakeInstructionPostOrder() const { + auto channel_dependency_map = ComputeChannelDependencies(); std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; @@ -372,7 +428,8 @@ std::vector HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); + ComputeInstructionPostOrder(channel_dependency_map, &post_order, + instruction.get(), &visited); } } post_order.insert(post_order.end(), trace_instructions.begin(), @@ -676,12 +733,33 @@ std::unique_ptr HloComputation::ComputeReachability() const { const auto& all = MakeInstructionPostOrder(); auto result = absl::make_unique(all); + auto channel_dependency_map = ComputeChannelDependencies(); std::vector inputs; for (const HloInstruction* hlo : all) { inputs.assign(hlo->operands().begin(), hlo->operands().end()); inputs.insert(inputs.end(), hlo->control_predecessors().begin(), hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + const auto& dependencies = channel_dependency_map[hlo->channel_id()]; + absl::c_copy(dependencies, std::back_inserter(inputs)); + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + const auto& dependencies = + channel_dependency_map[all_reduce_id.value()]; + absl::c_copy(dependencies, std::back_inserter(inputs)); + } + break; + } + default: + break; + } + result->FastSetReachabilityToUnion(inputs, hlo); } return result; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 49ed65910f5..faa33f0f90e 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -399,6 +399,13 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; + // Returns a map from channel-id to directed dependencies of the channel + // instructions. For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. + std::map> ComputeChannelDependencies() + const; + string name_; int64 unique_id_; HloInstruction* root_instruction_; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index e4c54703313..f7ed1b0316b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -691,6 +691,27 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -} // namespace +TEST_F(HloComputationTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = computation->ComputeReachability(); + EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + +} // namespace } // namespace xla