Disabling the hoisting of converts in the presence of conditionals, where the branch roots are not of the same type.

PiperOrigin-RevId: 334705567
Change-Id: I5a94a89ae6dec40a5993defd1b290ae8bb0682a9
This commit is contained in:
A. Unique TensorFlower 2020-09-30 16:41:07 -07:00 committed by TensorFlower Gardener
parent 632fc11f67
commit b9899df597
2 changed files with 158 additions and 107 deletions

View File

@ -351,11 +351,17 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
return false;
}
// Determining whether all branch roots are tuples
for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
HloInstruction* branch_root =
conditional->branch_computation(branch_num)->root_instruction();
if (branch_root->opcode() != HloOpcode::kTuple) {
return false;
}
}
HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
if (old_root->opcode() != HloOpcode::kTuple) {
return false;
} else {
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
// Identify the gte using `index'.
auto find_gte = [](const HloInstruction* conditional_result,
@ -457,8 +463,8 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
std::vector<HloInstruction*> new_operands;
for (HloInstruction* op : hoist->operands()) {
HloInstruction* gte = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(
op->shape(), conditional, map_inst_to_tuple_index[op]));
HloInstruction::CreateGetTupleElement(op->shape(), conditional,
map_inst_to_tuple_index[op]));
new_operands.push_back(gte);
}
HloInstruction* hoisted = conditional_parent->AddInstruction(
@ -470,7 +476,6 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
// No need to explicitly delete a hoisted instruction since if its dead
// then the subsequent DCE will remove it.
}
}
VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
return true;
}

View File

@ -78,6 +78,52 @@ ENTRY main {
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
}
TEST_F(ConditionalCodeMotionTest, VerifyConditionalAnalysisWithWhileTuple) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
body {
%p_body = (f32[2], bf16[2], s32[]) parameter(0)
%val = f32[2] get-tuple-element(p_body), index=0
%val2 = bf16[2] get-tuple-element(p_body), index=1
%const = s32[] constant(-1)
ROOT root = (f32[2], bf16[], s32[]) tuple(%val, %val2, %const)
}
condition {
%p_cond = (f32[2], bf16[2], s32[]) parameter(0)
%gte = s32[] get-tuple-element(%p_cond), index=2
%const = s32[] constant(42)
ROOT result = pred[] compare(%gte, %const), direction=EQ
}
on_true {
%arg_tuple.1 = f32[2] parameter(0)
%const = s32[] constant(42)
%add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1)
%convert.2894 = bf16[2] convert(f32[2] %add.8493)
ROOT %tuple.1 = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const)
}
on_false {
%arg_tuple.1 = f32[2] parameter(0)
%const = s32[] constant(42)
%add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1)
%convert.2894 = bf16[2] convert(f32[2] %add.8493)
%while_init = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const)
ROOT while = (f32[2], bf16[2], s32[]) while(%while_init), condition=condition, body=body
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = f32[2] parameter(1)
ROOT conditional = (f32[2], bf16[2], s32[]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
absl::string_view hlo_string =
R"(