Disabling the hoisting of converts in the presence of conditionals, where the branch roots are not of the same type.
PiperOrigin-RevId: 334705567 Change-Id: I5a94a89ae6dec40a5993defd1b290ae8bb0682a9
This commit is contained in:
parent
632fc11f67
commit
b9899df597
@ -351,11 +351,17 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
|
||||
return false;
|
||||
}
|
||||
|
||||
// Determining whether all branch roots are tuples
|
||||
for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
|
||||
HloInstruction* branch_root =
|
||||
conditional->branch_computation(branch_num)->root_instruction();
|
||||
if (branch_root->opcode() != HloOpcode::kTuple) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
HloInstruction* old_root =
|
||||
conditional->branch_computation(0)->root_instruction();
|
||||
if (old_root->opcode() != HloOpcode::kTuple) {
|
||||
return false;
|
||||
} else {
|
||||
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
|
||||
// Identify the gte using `index'.
|
||||
auto find_gte = [](const HloInstruction* conditional_result,
|
||||
@ -457,8 +463,8 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
for (HloInstruction* op : hoist->operands()) {
|
||||
HloInstruction* gte = conditional_parent->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(
|
||||
op->shape(), conditional, map_inst_to_tuple_index[op]));
|
||||
HloInstruction::CreateGetTupleElement(op->shape(), conditional,
|
||||
map_inst_to_tuple_index[op]));
|
||||
new_operands.push_back(gte);
|
||||
}
|
||||
HloInstruction* hoisted = conditional_parent->AddInstruction(
|
||||
@ -470,7 +476,6 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
|
||||
// No need to explicitly delete a hoisted instruction since if its dead
|
||||
// then the subsequent DCE will remove it.
|
||||
}
|
||||
}
|
||||
VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -78,6 +78,52 @@ ENTRY main {
|
||||
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, VerifyConditionalAnalysisWithWhileTuple) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
HloModule RemoveDotOpOut
|
||||
|
||||
body {
|
||||
%p_body = (f32[2], bf16[2], s32[]) parameter(0)
|
||||
%val = f32[2] get-tuple-element(p_body), index=0
|
||||
%val2 = bf16[2] get-tuple-element(p_body), index=1
|
||||
%const = s32[] constant(-1)
|
||||
ROOT root = (f32[2], bf16[], s32[]) tuple(%val, %val2, %const)
|
||||
}
|
||||
|
||||
condition {
|
||||
%p_cond = (f32[2], bf16[2], s32[]) parameter(0)
|
||||
%gte = s32[] get-tuple-element(%p_cond), index=2
|
||||
%const = s32[] constant(42)
|
||||
ROOT result = pred[] compare(%gte, %const), direction=EQ
|
||||
}
|
||||
|
||||
on_true {
|
||||
%arg_tuple.1 = f32[2] parameter(0)
|
||||
%const = s32[] constant(42)
|
||||
%add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1)
|
||||
%convert.2894 = bf16[2] convert(f32[2] %add.8493)
|
||||
ROOT %tuple.1 = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const)
|
||||
}
|
||||
on_false {
|
||||
%arg_tuple.1 = f32[2] parameter(0)
|
||||
%const = s32[] constant(42)
|
||||
%add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1)
|
||||
%convert.2894 = bf16[2] convert(f32[2] %add.8493)
|
||||
%while_init = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const)
|
||||
ROOT while = (f32[2], bf16[2], s32[]) while(%while_init), condition=condition, body=body
|
||||
}
|
||||
ENTRY main {
|
||||
pred.1 = pred[] parameter(0)
|
||||
arg_tuple.11 = f32[2] parameter(1)
|
||||
ROOT conditional = (f32[2], bf16[2], s32[]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
ConditionalCodeMotion pass(true, true);
|
||||
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user