Handle communicating instructions in HloComputation::ComputeReachability
Send&recv instructions and cross-replica-sum instructions are imposing extra dependencies via the channel id or all reduce id. This CL teaches the reachability calculation logic in hlo computation to correctly account for these "invisible" dependencies. The main purpose is to stop multi output fusion from generating dependency cyclies via communicating instructions. PiperOrigin-RevId: 209593997
This commit is contained in:
parent
ad4018ebb6
commit
aeab291563
@ -321,6 +321,7 @@ void ComputeComputationPostOrder(
|
|||||||
enum State { kVisiting, kVisited };
|
enum State { kVisiting, kVisited };
|
||||||
|
|
||||||
void ComputeInstructionPostOrder(
|
void ComputeInstructionPostOrder(
|
||||||
|
std::map<int64, std::vector<HloInstruction*>> channel_dependency_map,
|
||||||
std::vector<HloInstruction*>* post_order, HloInstruction* root,
|
std::vector<HloInstruction*>* post_order, HloInstruction* root,
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
|
tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
|
||||||
std::vector<HloInstruction*> dfs_stack;
|
std::vector<HloInstruction*> dfs_stack;
|
||||||
@ -355,12 +356,67 @@ void ComputeInstructionPostOrder(
|
|||||||
for (HloInstruction* op : current->control_predecessors()) {
|
for (HloInstruction* op : current->control_predecessors()) {
|
||||||
dfs_stack.emplace_back(op);
|
dfs_stack.emplace_back(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add inputs for send->recv_done dependencies and cross-replica-sum
|
||||||
|
// dependencies.
|
||||||
|
switch (current->opcode()) {
|
||||||
|
case HloOpcode::kRecvDone: {
|
||||||
|
const auto& dependencies =
|
||||||
|
channel_dependency_map[current->channel_id()];
|
||||||
|
for (HloInstruction* op : dependencies) {
|
||||||
|
dfs_stack.emplace_back(op);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case HloOpcode::kCrossReplicaSum: {
|
||||||
|
auto all_reduce_id = current->all_reduce_id();
|
||||||
|
if (all_reduce_id) {
|
||||||
|
const auto& dependencies =
|
||||||
|
channel_dependency_map[all_reduce_id.value()];
|
||||||
|
for (HloInstruction* op : dependencies) {
|
||||||
|
dfs_stack.emplace_back(op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
std::map<int64, std::vector<HloInstruction*>>
|
||||||
|
HloComputation::ComputeChannelDependencies() const {
|
||||||
|
std::map<int64, std::vector<HloInstruction*>> channel_dependency_map;
|
||||||
|
for (const auto& instruction : instructions_) {
|
||||||
|
switch (instruction->opcode()) {
|
||||||
|
case HloOpcode::kSend: {
|
||||||
|
channel_dependency_map[instruction->channel_id()].push_back(
|
||||||
|
instruction.get());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case HloOpcode::kCrossReplicaSum: {
|
||||||
|
auto all_reduce_id = instruction->all_reduce_id();
|
||||||
|
if (all_reduce_id) {
|
||||||
|
auto& dependencies = channel_dependency_map[all_reduce_id.value()];
|
||||||
|
absl::c_copy(instruction->operands(),
|
||||||
|
std::back_inserter(dependencies));
|
||||||
|
absl::c_copy(instruction->control_predecessors(),
|
||||||
|
std::back_inserter(dependencies));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return channel_dependency_map;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
|
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
|
||||||
|
auto channel_dependency_map = ComputeChannelDependencies();
|
||||||
std::vector<HloInstruction*> post_order;
|
std::vector<HloInstruction*> post_order;
|
||||||
post_order.reserve(instruction_count());
|
post_order.reserve(instruction_count());
|
||||||
std::vector<HloInstruction*> trace_instructions;
|
std::vector<HloInstruction*> trace_instructions;
|
||||||
@ -372,7 +428,8 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
|
|||||||
// users).
|
// users).
|
||||||
trace_instructions.push_back(instruction.get());
|
trace_instructions.push_back(instruction.get());
|
||||||
} else if (instruction->users().empty()) {
|
} else if (instruction->users().empty()) {
|
||||||
ComputeInstructionPostOrder(&post_order, instruction.get(), &visited);
|
ComputeInstructionPostOrder(channel_dependency_map, &post_order,
|
||||||
|
instruction.get(), &visited);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
post_order.insert(post_order.end(), trace_instructions.begin(),
|
post_order.insert(post_order.end(), trace_instructions.begin(),
|
||||||
@ -676,12 +733,33 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
|
|||||||
const {
|
const {
|
||||||
const auto& all = MakeInstructionPostOrder();
|
const auto& all = MakeInstructionPostOrder();
|
||||||
auto result = absl::make_unique<HloReachabilityMap>(all);
|
auto result = absl::make_unique<HloReachabilityMap>(all);
|
||||||
|
auto channel_dependency_map = ComputeChannelDependencies();
|
||||||
|
|
||||||
std::vector<HloInstruction*> inputs;
|
std::vector<HloInstruction*> inputs;
|
||||||
for (const HloInstruction* hlo : all) {
|
for (const HloInstruction* hlo : all) {
|
||||||
inputs.assign(hlo->operands().begin(), hlo->operands().end());
|
inputs.assign(hlo->operands().begin(), hlo->operands().end());
|
||||||
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
|
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
|
||||||
hlo->control_predecessors().end());
|
hlo->control_predecessors().end());
|
||||||
|
|
||||||
|
switch (hlo->opcode()) {
|
||||||
|
case HloOpcode::kRecvDone: {
|
||||||
|
const auto& dependencies = channel_dependency_map[hlo->channel_id()];
|
||||||
|
absl::c_copy(dependencies, std::back_inserter(inputs));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case HloOpcode::kCrossReplicaSum: {
|
||||||
|
auto all_reduce_id = hlo->all_reduce_id();
|
||||||
|
if (all_reduce_id) {
|
||||||
|
const auto& dependencies =
|
||||||
|
channel_dependency_map[all_reduce_id.value()];
|
||||||
|
absl::c_copy(dependencies, std::back_inserter(inputs));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
result->FastSetReachabilityToUnion(inputs, hlo);
|
result->FastSetReachabilityToUnion(inputs, hlo);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
@ -399,6 +399,13 @@ class HloComputation {
|
|||||||
// Internal helper to collect unreachable roots.
|
// Internal helper to collect unreachable roots.
|
||||||
std::vector<HloInstruction*> CollectUnreachableRoots() const;
|
std::vector<HloInstruction*> CollectUnreachableRoots() const;
|
||||||
|
|
||||||
|
// Returns a map from channel-id to directed dependencies of the channel
|
||||||
|
// instructions. For send&recv pairs it means the send instruction and for
|
||||||
|
// cross-replica-sum the union of the dependencies for all participating
|
||||||
|
// instructions.
|
||||||
|
std::map<int64, std::vector<HloInstruction*>> ComputeChannelDependencies()
|
||||||
|
const;
|
||||||
|
|
||||||
string name_;
|
string name_;
|
||||||
int64 unique_id_;
|
int64 unique_id_;
|
||||||
HloInstruction* root_instruction_;
|
HloInstruction* root_instruction_;
|
||||||
|
@ -691,6 +691,27 @@ TEST_F(HloComputationTest, StringificationCanonical) {
|
|||||||
EXPECT_EQ(computation->ToString(options), expected_computation2);
|
EXPECT_EQ(computation->ToString(options), expected_computation2);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
TEST_F(HloComputationTest, ChannelReachability) {
|
||||||
|
const Shape shape = ShapeUtil::MakeShape(F32, {5, 7});
|
||||||
|
HloComputation::Builder builder("ChannelReachability");
|
||||||
|
auto param = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, shape, "param"));
|
||||||
|
auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
|
||||||
|
auto send =
|
||||||
|
builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1));
|
||||||
|
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
|
||||||
|
auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
|
||||||
|
auto recv =
|
||||||
|
builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1));
|
||||||
|
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
|
||||||
|
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
auto computation = module->AddEntryComputation(builder.Build(recv_done));
|
||||||
|
auto reachability = computation->ComputeReachability();
|
||||||
|
EXPECT_TRUE(reachability->IsReachable(param, recv_done));
|
||||||
|
EXPECT_FALSE(reachability->IsReachable(send, recv));
|
||||||
|
EXPECT_FALSE(reachability->IsReachable(send_done, recv));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user