Disable the hoisting of converts out of conditionals whose branch roots are not all of the same type.
PiperOrigin-RevId: 334705567 Change-Id: I5a94a89ae6dec40a5993defd1b290ae8bb0682a9
This commit is contained in:
parent
632fc11f67
commit
b9899df597
@ -351,125 +351,130 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
|
||||
return false;
|
||||
}
|
||||
|
||||
HloInstruction* old_root =
|
||||
conditional->branch_computation(0)->root_instruction();
|
||||
if (old_root->opcode() != HloOpcode::kTuple) {
|
||||
return false;
|
||||
} else {
|
||||
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
|
||||
// Identify the gte using `index'.
|
||||
auto find_gte = [](const HloInstruction* conditional_result,
|
||||
int64 index) -> HloInstruction* {
|
||||
for (HloInstruction* instr : conditional_result->users()) {
|
||||
if (instr->opcode() != HloOpcode::kGetTupleElement) {
|
||||
return nullptr;
|
||||
}
|
||||
if (instr->tuple_index() == index) {
|
||||
return instr;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
// Captures tuple indices referring to converts to be rematerialized/hoisted.
|
||||
absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
|
||||
old_root, branch_count, conditional, is_layout_sensitive);
|
||||
|
||||
// Exit if we cannot find any converts to be hoisted.
|
||||
if (kspecial_convert.empty()) {
|
||||
// Determining whether all branch roots are tuples
|
||||
for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
|
||||
HloInstruction* branch_root =
|
||||
conditional->branch_computation(branch_num)->root_instruction();
|
||||
if (branch_root->opcode() != HloOpcode::kTuple) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
RestructureConditionalInstruction(conditional->parent(), conditional));
|
||||
|
||||
for (int branch = 0; branch < branch_count; branch++) {
|
||||
old_root = conditional->branch_computation(branch)->root_instruction();
|
||||
absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
|
||||
std::vector<HloInstruction*> new_operands(old_root->operand_count());
|
||||
absl::flat_hash_set<HloInstruction*> to_hoist_set;
|
||||
|
||||
for (int64 operand_num = 0; operand_num < old_root->operand_count();
|
||||
++operand_num) {
|
||||
map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
|
||||
operand_num;
|
||||
HloInstruction* old_root =
|
||||
conditional->branch_computation(0)->root_instruction();
|
||||
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
|
||||
// Identify the gte using `index'.
|
||||
auto find_gte = [](const HloInstruction* conditional_result,
|
||||
int64 index) -> HloInstruction* {
|
||||
for (HloInstruction* instr : conditional_result->users()) {
|
||||
if (instr->opcode() != HloOpcode::kGetTupleElement) {
|
||||
return nullptr;
|
||||
}
|
||||
for (int64 operand_num = 0; operand_num < old_root->operand_count();
|
||||
++operand_num) {
|
||||
HloInstruction* hoist = old_root->mutable_operand(operand_num);
|
||||
if (!kspecial_convert.contains(operand_num)) {
|
||||
new_operands[operand_num] = old_root->mutable_operand(operand_num);
|
||||
continue;
|
||||
}
|
||||
|
||||
to_hoist_set.insert(hoist);
|
||||
int64 new_tuple_count = old_root->operand_count();
|
||||
|
||||
// Replace the hoisted instr in the tuple with the operand/operands.
|
||||
// We will replace at least one of the operands of the hoist at the
|
||||
// tuple place; the rest will be added at the end.
|
||||
bool inplace = true;
|
||||
CHECK(!hoist->operands().empty());
|
||||
for (HloInstruction* prod : hoist->operands()) {
|
||||
if (inplace) {
|
||||
map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
|
||||
new_operands[map_inst_to_tuple_index[hoist]] = prod;
|
||||
inplace = false;
|
||||
} else {
|
||||
map_inst_to_tuple_index[prod] = new_tuple_count++;
|
||||
new_operands.push_back(prod);
|
||||
}
|
||||
}
|
||||
if (instr->tuple_index() == index) {
|
||||
return instr;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
// Create the new root instruction.
|
||||
HloComputation* cur_branch = conditional->branch_computation(branch);
|
||||
HloInstruction* new_branch_root =
|
||||
cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
|
||||
// The shape can vary since the operands to convert are now
|
||||
// being returned through the branches' root.
|
||||
cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
|
||||
TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));
|
||||
// Captures tuple indices referring to converts to be rematerialized/hoisted.
|
||||
absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
|
||||
old_root, branch_count, conditional, is_layout_sensitive);
|
||||
|
||||
// Only one of the branches needs to change the conditional->parent().
|
||||
if (branch != 0) {
|
||||
// Exit if we cannot find any converts to be hoisted.
|
||||
if (kspecial_convert.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
RestructureConditionalInstruction(conditional->parent(), conditional));
|
||||
|
||||
for (int branch = 0; branch < branch_count; branch++) {
|
||||
old_root = conditional->branch_computation(branch)->root_instruction();
|
||||
absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
|
||||
std::vector<HloInstruction*> new_operands(old_root->operand_count());
|
||||
absl::flat_hash_set<HloInstruction*> to_hoist_set;
|
||||
|
||||
for (int64 operand_num = 0; operand_num < old_root->operand_count();
|
||||
++operand_num) {
|
||||
map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
|
||||
operand_num;
|
||||
}
|
||||
for (int64 operand_num = 0; operand_num < old_root->operand_count();
|
||||
++operand_num) {
|
||||
HloInstruction* hoist = old_root->mutable_operand(operand_num);
|
||||
if (!kspecial_convert.contains(operand_num)) {
|
||||
new_operands[operand_num] = old_root->mutable_operand(operand_num);
|
||||
continue;
|
||||
}
|
||||
HloComputation* conditional_parent = conditional->parent();
|
||||
HloInstruction* newconditional =
|
||||
conditional_parent->AddInstruction(HloInstruction::CreateConditional(
|
||||
cur_branch->root_instruction()->shape(),
|
||||
conditional->mutable_operand(0),
|
||||
absl::MakeSpan(conditional->branch_computations()),
|
||||
absl::MakeSpan(conditional->operands()).subspan(1)));
|
||||
// Ensure that all the users of conditional refer to the new one.
|
||||
TF_RETURN_IF_ERROR(
|
||||
conditional->ReplaceAllUsesWithDifferentShape(newconditional));
|
||||
TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
|
||||
conditional = newconditional;
|
||||
// Add the hoisted instructions in the parent.
|
||||
for (HloInstruction* hoist : to_hoist_set) {
|
||||
VLOG(2) << "Hoisting instruction:" << hoist->ToString();
|
||||
int64 hoist_index = map_inst_to_tuple_index[hoist];
|
||||
// Find out the gte that captured the hoisted instr result.
|
||||
HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
|
||||
CHECK(gte_hoist != nullptr);
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
for (HloInstruction* op : hoist->operands()) {
|
||||
HloInstruction* gte = conditional_parent->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(
|
||||
op->shape(), conditional, map_inst_to_tuple_index[op]));
|
||||
new_operands.push_back(gte);
|
||||
|
||||
to_hoist_set.insert(hoist);
|
||||
int64 new_tuple_count = old_root->operand_count();
|
||||
|
||||
// Replace the hoisted instr in the tuple with the operand/operands.
|
||||
// We will replace at least one of the operands of the hoist at the
|
||||
// tuple place; the rest will be added at the end.
|
||||
bool inplace = true;
|
||||
CHECK(!hoist->operands().empty());
|
||||
for (HloInstruction* prod : hoist->operands()) {
|
||||
if (inplace) {
|
||||
map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
|
||||
new_operands[map_inst_to_tuple_index[hoist]] = prod;
|
||||
inplace = false;
|
||||
} else {
|
||||
map_inst_to_tuple_index[prod] = new_tuple_count++;
|
||||
new_operands.push_back(prod);
|
||||
}
|
||||
HloInstruction* hoisted = conditional_parent->AddInstruction(
|
||||
hoist->CloneWithNewOperands(hoist->shape(), new_operands));
|
||||
VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
|
||||
TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
|
||||
TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
|
||||
}
|
||||
// No need to explicitly delete a hoisted instruction since if it's dead
|
||||
// then the subsequent DCE will remove it.
|
||||
}
|
||||
|
||||
// Create the new root instruction.
|
||||
HloComputation* cur_branch = conditional->branch_computation(branch);
|
||||
HloInstruction* new_branch_root =
|
||||
cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
|
||||
// The shape can vary since the operands to convert are now
|
||||
// being returned through the branches' root.
|
||||
cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
|
||||
TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));
|
||||
|
||||
// Only one of the branches needs to change the conditional->parent().
|
||||
if (branch != 0) {
|
||||
continue;
|
||||
}
|
||||
HloComputation* conditional_parent = conditional->parent();
|
||||
HloInstruction* newconditional =
|
||||
conditional_parent->AddInstruction(HloInstruction::CreateConditional(
|
||||
cur_branch->root_instruction()->shape(),
|
||||
conditional->mutable_operand(0),
|
||||
absl::MakeSpan(conditional->branch_computations()),
|
||||
absl::MakeSpan(conditional->operands()).subspan(1)));
|
||||
// Ensure that all the users of conditional refer to the new one.
|
||||
TF_RETURN_IF_ERROR(
|
||||
conditional->ReplaceAllUsesWithDifferentShape(newconditional));
|
||||
TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
|
||||
conditional = newconditional;
|
||||
// Add the hoisted instructions in the parent.
|
||||
for (HloInstruction* hoist : to_hoist_set) {
|
||||
VLOG(2) << "Hoisting instruction:" << hoist->ToString();
|
||||
int64 hoist_index = map_inst_to_tuple_index[hoist];
|
||||
// Find out the gte that captured the hoisted instr result.
|
||||
HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
|
||||
CHECK(gte_hoist != nullptr);
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
for (HloInstruction* op : hoist->operands()) {
|
||||
HloInstruction* gte = conditional_parent->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(op->shape(), conditional,
|
||||
map_inst_to_tuple_index[op]));
|
||||
new_operands.push_back(gte);
|
||||
}
|
||||
HloInstruction* hoisted = conditional_parent->AddInstruction(
|
||||
hoist->CloneWithNewOperands(hoist->shape(), new_operands));
|
||||
VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
|
||||
TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
|
||||
TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
|
||||
}
|
||||
// No need to explicitly delete a hoisted instruction since if it's dead
|
||||
// then the subsequent DCE will remove it.
|
||||
}
|
||||
VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
|
||||
return true;
|
||||
|
||||
@ -78,6 +78,52 @@ ENTRY main {
|
||||
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
|
||||
}
|
||||
|
||||
// Checks that ConditionalCodeMotion reports no change for a conditional whose
// branch roots are different kinds of instruction: the `on_true` root is a
// tuple, while the `on_false` root is a while loop (which produces the same
// tuple shape). Both branches contain a convert feeding the tuple.
TEST_F(ConditionalCodeMotionTest, VerifyConditionalAnalysisWithWhileTuple) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
HloModule RemoveDotOpOut
|
||||
|
||||
body {
|
||||
%p_body = (f32[2], bf16[2], s32[]) parameter(0)
|
||||
%val = f32[2] get-tuple-element(p_body), index=0
|
||||
%val2 = bf16[2] get-tuple-element(p_body), index=1
|
||||
%const = s32[] constant(-1)
|
||||
ROOT root = (f32[2], bf16[], s32[]) tuple(%val, %val2, %const)
|
||||
}
|
||||
|
||||
condition {
|
||||
%p_cond = (f32[2], bf16[2], s32[]) parameter(0)
|
||||
%gte = s32[] get-tuple-element(%p_cond), index=2
|
||||
%const = s32[] constant(42)
|
||||
ROOT result = pred[] compare(%gte, %const), direction=EQ
|
||||
}
|
||||
|
||||
on_true {
|
||||
%arg_tuple.1 = f32[2] parameter(0)
|
||||
%const = s32[] constant(42)
|
||||
%add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1)
|
||||
%convert.2894 = bf16[2] convert(f32[2] %add.8493)
|
||||
ROOT %tuple.1 = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const)
|
||||
}
|
||||
on_false {
|
||||
%arg_tuple.1 = f32[2] parameter(0)
|
||||
%const = s32[] constant(42)
|
||||
%add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1)
|
||||
%convert.2894 = bf16[2] convert(f32[2] %add.8493)
|
||||
%while_init = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const)
|
||||
ROOT while = (f32[2], bf16[2], s32[]) while(%while_init), condition=condition, body=body
|
||||
}
|
||||
ENTRY main {
|
||||
pred.1 = pred[] parameter(0)
|
||||
arg_tuple.11 = f32[2] parameter(1)
|
||||
ROOT conditional = (f32[2], bf16[2], s32[]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
// NOTE(review): the meaning of the two boolean constructor arguments is not
// visible in this chunk -- confirm against the ConditionalCodeMotion header.
ConditionalCodeMotion pass(true, true);
|
||||
// The pass must leave the module untouched, i.e. Run() returns false.
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user