diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index f1529da7513..855e75a76e0 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -351,125 +351,130 @@ StatusOr ConvertSpecialMove(HloInstruction* conditional, return false; } - HloInstruction* old_root = - conditional->branch_computation(0)->root_instruction(); - if (old_root->opcode() != HloOpcode::kTuple) { - return false; - } else { - VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString(); - // Identify the gte using `index'. - auto find_gte = [](const HloInstruction* conditional_result, - int64 index) -> HloInstruction* { - for (HloInstruction* instr : conditional_result->users()) { - if (instr->opcode() != HloOpcode::kGetTupleElement) { - return nullptr; - } - if (instr->tuple_index() == index) { - return instr; - } - } - return nullptr; - }; - - // Captures tuple indices refering to converts to be rematerialized/hoisted. - absl::flat_hash_set kspecial_convert = FindSpecialConverts( - old_root, branch_count, conditional, is_layout_sensitive); - - // Exit if we cannot find any converts to be hoisted. 
- if (kspecial_convert.empty()) { + // Determining whether all branch roots are tuples + for (int branch_num = 0; branch_num < branch_count; ++branch_num) { + HloInstruction* branch_root = + conditional->branch_computation(branch_num)->root_instruction(); + if (branch_root->opcode() != HloOpcode::kTuple) { return false; } + } - TF_RETURN_IF_ERROR( - RestructureConditionalInstruction(conditional->parent(), conditional)); - - for (int branch = 0; branch < branch_count; branch++) { - old_root = conditional->branch_computation(branch)->root_instruction(); - absl::flat_hash_map map_inst_to_tuple_index; - std::vector new_operands(old_root->operand_count()); - absl::flat_hash_set to_hoist_set; - - for (int64 operand_num = 0; operand_num < old_root->operand_count(); - ++operand_num) { - map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] = - operand_num; + HloInstruction* old_root = + conditional->branch_computation(0)->root_instruction(); + VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString(); + // Identify the gte using `index'. + auto find_gte = [](const HloInstruction* conditional_result, + int64 index) -> HloInstruction* { + for (HloInstruction* instr : conditional_result->users()) { + if (instr->opcode() != HloOpcode::kGetTupleElement) { + return nullptr; } - for (int64 operand_num = 0; operand_num < old_root->operand_count(); - ++operand_num) { - HloInstruction* hoist = old_root->mutable_operand(operand_num); - if (!kspecial_convert.contains(operand_num)) { - new_operands[operand_num] = old_root->mutable_operand(operand_num); - continue; - } - - to_hoist_set.insert(hoist); - int64 new_tuple_count = old_root->operand_count(); - - // Replace the hoisted instr in the tuple with the operand/operands. - // We will replace at least one of the operands of the hoist at the - // tuple place; the rest will be added at the end. 
- bool inplace = true; - CHECK(!hoist->operands().empty()); - for (HloInstruction* prod : hoist->operands()) { - if (inplace) { - map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist]; - new_operands[map_inst_to_tuple_index[hoist]] = prod; - inplace = false; - } else { - map_inst_to_tuple_index[prod] = new_tuple_count++; - new_operands.push_back(prod); - } - } + if (instr->tuple_index() == index) { + return instr; } + } + return nullptr; + }; - // Create the new root instruction. - HloComputation* cur_branch = conditional->branch_computation(branch); - HloInstruction* new_branch_root = - cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands)); - // The shape can vary since the operands to convert are now - // being returned through the branches' root. - cur_branch->set_root_instruction(new_branch_root, true /*new shape*/); - TF_CHECK_OK(cur_branch->RemoveInstruction(old_root)); + // Captures tuple indices referring to converts to be rematerialized/hoisted. + absl::flat_hash_set kspecial_convert = FindSpecialConverts( + old_root, branch_count, conditional, is_layout_sensitive); - // Only one of the branches needs to change the conditional->parent(). - if (branch != 0) { + // Exit if we cannot find any converts to be hoisted. 
+ if (kspecial_convert.empty()) { + return false; + } + + TF_RETURN_IF_ERROR( + RestructureConditionalInstruction(conditional->parent(), conditional)); + + for (int branch = 0; branch < branch_count; branch++) { + old_root = conditional->branch_computation(branch)->root_instruction(); + absl::flat_hash_map map_inst_to_tuple_index; + std::vector new_operands(old_root->operand_count()); + absl::flat_hash_set to_hoist_set; + + for (int64 operand_num = 0; operand_num < old_root->operand_count(); + ++operand_num) { + map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] = + operand_num; + } + for (int64 operand_num = 0; operand_num < old_root->operand_count(); + ++operand_num) { + HloInstruction* hoist = old_root->mutable_operand(operand_num); + if (!kspecial_convert.contains(operand_num)) { + new_operands[operand_num] = old_root->mutable_operand(operand_num); continue; } - HloComputation* conditional_parent = conditional->parent(); - HloInstruction* newconditional = - conditional_parent->AddInstruction(HloInstruction::CreateConditional( - cur_branch->root_instruction()->shape(), - conditional->mutable_operand(0), - absl::MakeSpan(conditional->branch_computations()), - absl::MakeSpan(conditional->operands()).subspan(1))); - // Ensure that all the users of conditional refer to the new one. - TF_RETURN_IF_ERROR( - conditional->ReplaceAllUsesWithDifferentShape(newconditional)); - TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional)); - conditional = newconditional; - // Add the hoisted instructions in the parent. - for (HloInstruction* hoist : to_hoist_set) { - VLOG(2) << "Hoisting instruction:" << hoist->ToString(); - int64 hoist_index = map_inst_to_tuple_index[hoist]; - // Find out the gte that captured the hoisted instr result. 
- HloInstruction* gte_hoist = find_gte(conditional, hoist_index); - CHECK(gte_hoist != nullptr); - std::vector new_operands; - for (HloInstruction* op : hoist->operands()) { - HloInstruction* gte = conditional_parent->AddInstruction( - HloInstruction::CreateGetTupleElement( - op->shape(), conditional, map_inst_to_tuple_index[op])); - new_operands.push_back(gte); + + to_hoist_set.insert(hoist); + int64 new_tuple_count = old_root->operand_count(); + + // Replace the hoisted instr in the tuple with the operand/operands. + // We will replace at least one of the operands of the hoist at the + // tuple place; the rest will be added at the end. + bool inplace = true; + CHECK(!hoist->operands().empty()); + for (HloInstruction* prod : hoist->operands()) { + if (inplace) { + map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist]; + new_operands[map_inst_to_tuple_index[hoist]] = prod; + inplace = false; + } else { + map_inst_to_tuple_index[prod] = new_tuple_count++; + new_operands.push_back(prod); } - HloInstruction* hoisted = conditional_parent->AddInstruction( - hoist->CloneWithNewOperands(hoist->shape(), new_operands)); - VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString(); - TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); - TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist)); } - // No need to explicitly delete a hoisted instruction since if its dead - // then the subsequent DCE will remove it. } + + // Create the new root instruction. + HloComputation* cur_branch = conditional->branch_computation(branch); + HloInstruction* new_branch_root = + cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands)); + // The shape can vary since the operands to convert are now + // being returned through the branches' root. + cur_branch->set_root_instruction(new_branch_root, true /*new shape*/); + TF_CHECK_OK(cur_branch->RemoveInstruction(old_root)); + + // Only one of the branches needs to change the conditional->parent(). 
+ if (branch != 0) { + continue; + } + HloComputation* conditional_parent = conditional->parent(); + HloInstruction* newconditional = + conditional_parent->AddInstruction(HloInstruction::CreateConditional( + cur_branch->root_instruction()->shape(), + conditional->mutable_operand(0), + absl::MakeSpan(conditional->branch_computations()), + absl::MakeSpan(conditional->operands()).subspan(1))); + // Ensure that all the users of conditional refer to the new one. + TF_RETURN_IF_ERROR( + conditional->ReplaceAllUsesWithDifferentShape(newconditional)); + TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional)); + conditional = newconditional; + // Add the hoisted instructions in the parent. + for (HloInstruction* hoist : to_hoist_set) { + VLOG(2) << "Hoisting instruction:" << hoist->ToString(); + int64 hoist_index = map_inst_to_tuple_index[hoist]; + // Find out the gte that captured the hoisted instr result. + HloInstruction* gte_hoist = find_gte(conditional, hoist_index); + CHECK(gte_hoist != nullptr); + std::vector new_operands; + for (HloInstruction* op : hoist->operands()) { + HloInstruction* gte = conditional_parent->AddInstruction( + HloInstruction::CreateGetTupleElement(op->shape(), conditional, + map_inst_to_tuple_index[op])); + new_operands.push_back(gte); + } + HloInstruction* hoisted = conditional_parent->AddInstruction( + hoist->CloneWithNewOperands(hoist->shape(), new_operands)); + VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString(); + TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); + TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist)); + } + // No need to explicitly delete a hoisted instruction since if it's dead + // then the subsequent DCE will remove it. 
} VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString(); return true; diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index c974a0d005b..3b40acf54e3 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -78,6 +78,52 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement()))); } +TEST_F(ConditionalCodeMotionTest, VerifyConditionalAnalysisWithWhileTuple) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + + body { + %p_body = (f32[2], bf16[2], s32[]) parameter(0) + %val = f32[2] get-tuple-element(p_body), index=0 + %val2 = bf16[2] get-tuple-element(p_body), index=1 + %const = s32[] constant(-1) + ROOT root = (f32[2], bf16[], s32[]) tuple(%val, %val2, %const) + } + + condition { + %p_cond = (f32[2], bf16[2], s32[]) parameter(0) + %gte = s32[] get-tuple-element(%p_cond), index=2 + %const = s32[] constant(42) + ROOT result = pred[] compare(%gte, %const), direction=EQ + } + + on_true { + %arg_tuple.1 = f32[2] parameter(0) + %const = s32[] constant(42) + %add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1) + %convert.2894 = bf16[2] convert(f32[2] %add.8493) + ROOT %tuple.1 = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const) + } + on_false { + %arg_tuple.1 = f32[2] parameter(0) + %const = s32[] constant(42) + %add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1) + %convert.2894 = bf16[2] convert(f32[2] %add.8493) + %while_init = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const) + ROOT while = (f32[2], bf16[2], s32[]) while(%while_init), condition=condition, body=body + } + ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = f32[2] parameter(1) + ROOT conditional = (f32[2], bf16[2], s32[]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, 
false_computation=on_false + } +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) { absl::string_view hlo_string = R"(