Disabling the hoisting of converts in the presence of conditionals, where the branch roots are not of the same type.

PiperOrigin-RevId: 334705567
Change-Id: I5a94a89ae6dec40a5993defd1b290ae8bb0682a9
This commit is contained in:
A. Unique TensorFlower 2020-09-30 16:41:07 -07:00 committed by TensorFlower Gardener
parent 632fc11f67
commit b9899df597
2 changed files with 158 additions and 107 deletions

View File

@ -351,125 +351,130 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
return false;
}
HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
if (old_root->opcode() != HloOpcode::kTuple) {
return false;
} else {
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
// Identify the gte using `index'.
auto find_gte = [](const HloInstruction* conditional_result,
int64 index) -> HloInstruction* {
for (HloInstruction* instr : conditional_result->users()) {
if (instr->opcode() != HloOpcode::kGetTupleElement) {
return nullptr;
}
if (instr->tuple_index() == index) {
return instr;
}
}
return nullptr;
};
// Captures tuple indices referring to converts to be rematerialized/hoisted.
absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
old_root, branch_count, conditional, is_layout_sensitive);
// Exit if we cannot find any converts to be hoisted.
if (kspecial_convert.empty()) {
// Determining whether all branch roots are tuples
for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
HloInstruction* branch_root =
conditional->branch_computation(branch_num)->root_instruction();
if (branch_root->opcode() != HloOpcode::kTuple) {
return false;
}
}
TF_RETURN_IF_ERROR(
RestructureConditionalInstruction(conditional->parent(), conditional));
for (int branch = 0; branch < branch_count; branch++) {
old_root = conditional->branch_computation(branch)->root_instruction();
absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
std::vector<HloInstruction*> new_operands(old_root->operand_count());
absl::flat_hash_set<HloInstruction*> to_hoist_set;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
operand_num;
HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
// Identify the gte using `index'.
auto find_gte = [](const HloInstruction* conditional_result,
int64 index) -> HloInstruction* {
for (HloInstruction* instr : conditional_result->users()) {
if (instr->opcode() != HloOpcode::kGetTupleElement) {
return nullptr;
}
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
HloInstruction* hoist = old_root->mutable_operand(operand_num);
if (!kspecial_convert.contains(operand_num)) {
new_operands[operand_num] = old_root->mutable_operand(operand_num);
continue;
}
to_hoist_set.insert(hoist);
int64 new_tuple_count = old_root->operand_count();
// Replace the hoisted instr in the tuple with the operand/operands.
// We will replace at least one of the operands of the hoist at the
// tuple place; the rest will be added at the end.
bool inplace = true;
CHECK(!hoist->operands().empty());
for (HloInstruction* prod : hoist->operands()) {
if (inplace) {
map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
new_operands[map_inst_to_tuple_index[hoist]] = prod;
inplace = false;
} else {
map_inst_to_tuple_index[prod] = new_tuple_count++;
new_operands.push_back(prod);
}
}
if (instr->tuple_index() == index) {
return instr;
}
}
return nullptr;
};
// Create the new root instruction.
HloComputation* cur_branch = conditional->branch_computation(branch);
HloInstruction* new_branch_root =
cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
// The shape can vary since the operands to convert are now
// being returned through the branches' root.
cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));
// Captures tuple indices referring to converts to be rematerialized/hoisted.
absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
old_root, branch_count, conditional, is_layout_sensitive);
// Only one of the branches needs to change the conditional->parent().
if (branch != 0) {
// Exit if we cannot find any converts to be hoisted.
if (kspecial_convert.empty()) {
return false;
}
TF_RETURN_IF_ERROR(
RestructureConditionalInstruction(conditional->parent(), conditional));
for (int branch = 0; branch < branch_count; branch++) {
old_root = conditional->branch_computation(branch)->root_instruction();
absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
std::vector<HloInstruction*> new_operands(old_root->operand_count());
absl::flat_hash_set<HloInstruction*> to_hoist_set;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
operand_num;
}
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
HloInstruction* hoist = old_root->mutable_operand(operand_num);
if (!kspecial_convert.contains(operand_num)) {
new_operands[operand_num] = old_root->mutable_operand(operand_num);
continue;
}
HloComputation* conditional_parent = conditional->parent();
HloInstruction* newconditional =
conditional_parent->AddInstruction(HloInstruction::CreateConditional(
cur_branch->root_instruction()->shape(),
conditional->mutable_operand(0),
absl::MakeSpan(conditional->branch_computations()),
absl::MakeSpan(conditional->operands()).subspan(1)));
// Ensure that all the users of conditional refer to the new one.
TF_RETURN_IF_ERROR(
conditional->ReplaceAllUsesWithDifferentShape(newconditional));
TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
conditional = newconditional;
// Add the hoisted instructions in the parent.
for (HloInstruction* hoist : to_hoist_set) {
VLOG(2) << "Hoisting instruction:" << hoist->ToString();
int64 hoist_index = map_inst_to_tuple_index[hoist];
// Find out the gte that captured the hoisted instr result.
HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
CHECK(gte_hoist != nullptr);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* op : hoist->operands()) {
HloInstruction* gte = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(
op->shape(), conditional, map_inst_to_tuple_index[op]));
new_operands.push_back(gte);
to_hoist_set.insert(hoist);
int64 new_tuple_count = old_root->operand_count();
// Replace the hoisted instr in the tuple with the operand/operands.
// We will replace at least one of the operands of the hoist at the
// tuple place; the rest will be added at the end.
bool inplace = true;
CHECK(!hoist->operands().empty());
for (HloInstruction* prod : hoist->operands()) {
if (inplace) {
map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
new_operands[map_inst_to_tuple_index[hoist]] = prod;
inplace = false;
} else {
map_inst_to_tuple_index[prod] = new_tuple_count++;
new_operands.push_back(prod);
}
HloInstruction* hoisted = conditional_parent->AddInstruction(
hoist->CloneWithNewOperands(hoist->shape(), new_operands));
VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
}
// No need to explicitly delete a hoisted instruction since, if it's dead,
// the subsequent DCE will remove it.
}
// Create the new root instruction.
HloComputation* cur_branch = conditional->branch_computation(branch);
HloInstruction* new_branch_root =
cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
// The shape can vary since the operands to convert are now
// being returned through the branches' root.
cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));
// Only one of the branches needs to change the conditional->parent().
if (branch != 0) {
continue;
}
HloComputation* conditional_parent = conditional->parent();
HloInstruction* newconditional =
conditional_parent->AddInstruction(HloInstruction::CreateConditional(
cur_branch->root_instruction()->shape(),
conditional->mutable_operand(0),
absl::MakeSpan(conditional->branch_computations()),
absl::MakeSpan(conditional->operands()).subspan(1)));
// Ensure that all the users of conditional refer to the new one.
TF_RETURN_IF_ERROR(
conditional->ReplaceAllUsesWithDifferentShape(newconditional));
TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
conditional = newconditional;
// Add the hoisted instructions in the parent.
for (HloInstruction* hoist : to_hoist_set) {
VLOG(2) << "Hoisting instruction:" << hoist->ToString();
int64 hoist_index = map_inst_to_tuple_index[hoist];
// Find out the gte that captured the hoisted instr result.
HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
CHECK(gte_hoist != nullptr);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* op : hoist->operands()) {
HloInstruction* gte = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(op->shape(), conditional,
map_inst_to_tuple_index[op]));
new_operands.push_back(gte);
}
HloInstruction* hoisted = conditional_parent->AddInstruction(
hoist->CloneWithNewOperands(hoist->shape(), new_operands));
VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
}
// No need to explicitly delete a hoisted instruction since, if it's dead,
// the subsequent DCE will remove it.
}
VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
return true;

View File

@ -78,6 +78,52 @@ ENTRY main {
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
}
// Checks that the pass reports no change when the branch roots of a
// conditional are different kinds of instructions: here the root of `on_true`
// is a tuple, whereas the root of `on_false` is a while loop whose result has
// the same tuple shape. Both branches contain a convert feeding the tuple, but
// the pass is expected to leave the module untouched.
TEST_F(ConditionalCodeMotionTest, VerifyConditionalAnalysisWithWhileTuple) {
  // Note: the declared shape of `body`'s ROOT tuple must agree with its
  // operands (f32[2], bf16[2], s32[]) and with the shape carried by the while
  // loop in `on_false`.
  absl::string_view hlo_string =
      R"(
HloModule RemoveDotOpOut
body {
%p_body = (f32[2], bf16[2], s32[]) parameter(0)
%val = f32[2] get-tuple-element(p_body), index=0
%val2 = bf16[2] get-tuple-element(p_body), index=1
%const = s32[] constant(-1)
ROOT root = (f32[2], bf16[2], s32[]) tuple(%val, %val2, %const)
}
condition {
%p_cond = (f32[2], bf16[2], s32[]) parameter(0)
%gte = s32[] get-tuple-element(%p_cond), index=2
%const = s32[] constant(42)
ROOT result = pred[] compare(%gte, %const), direction=EQ
}
on_true {
%arg_tuple.1 = f32[2] parameter(0)
%const = s32[] constant(42)
%add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1)
%convert.2894 = bf16[2] convert(f32[2] %add.8493)
ROOT %tuple.1 = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const)
}
on_false {
%arg_tuple.1 = f32[2] parameter(0)
%const = s32[] constant(42)
%add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1)
%convert.2894 = bf16[2] convert(f32[2] %add.8493)
%while_init = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const)
ROOT while = (f32[2], bf16[2], s32[]) while(%while_init), condition=condition, body=body
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = f32[2] parameter(1)
ROOT conditional = (f32[2], bf16[2], s32[]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass(true, true);
  // Run() yields true only if the module was modified; nothing may be hoisted
  // here.
  ASSERT_FALSE(pass.Run(module.get()).ValueOrDie());
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
absl::string_view hlo_string =
R"(