STT-tensorflow/tensorflow/compiler/xla/service/conditional_code_motion.cc
Yunxing Dai d4678c35bc Allow conditional code motion to move copies with same shape in order to fold them.
For cases like:

branch0 {
  ROOT copy ...
}
branch1 {
  ...
}
cond = conditional(branch0, branch1)
copy = copy(cond)

If the two copies have the same shape, we can move the copy across the conditional and cancel the two copies.

PiperOrigin-RevId: 342761683
Change-Id: I2a36785d353504b5166c336d4f00bcb4ef2ad19f
2020-11-16 18:12:26 -08:00

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/conditional_code_motion.h"
#include <iterator>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
namespace xla {
namespace conditional_opt {
class BoundaryVisitor {
public:
// Start with an existing conditional computation.
explicit BoundaryVisitor(HloInstruction* conditional) {
Boundary b(Boundary::Position::kInsideBranch);
b.mutable_operands().push_back(conditional);
worklist_.push_back(b);
}
// Start with an empty work list.
BoundaryVisitor() {}
// Get next boundary to visit.
Boundary PopNextBoundary() {
CHECK(!worklist_.empty());
Boundary b = worklist_.front();
worklist_.pop_front();
// If b is already visited, it must have multiple users and is already in
// new boundaries. Skip it. Only check the first operand of b because b
// is expected to have at least one operand, and all the operands in b
// must be identical instructions from different branches for b to be moved.
while (!worklist_.empty() && ContainsKey(visited_, b.operands()[0])) {
b = worklist_.front();
worklist_.pop_front();
}
visited_.insert(b.operands()[0]);
return b;
}
void AddToWorkList(const Boundary& b) {
CHECK(!b.operands().empty());
worklist_.push_back(b);
}
bool HasNextBoundary() {
while (!worklist_.empty()) {
Boundary b = worklist_.front();
if (!ContainsKey(visited_, b.operands()[0])) {
break;
}
worklist_.pop_front();
}
return !worklist_.empty();
}
private:
// The worklist is a deque that contains boundaries to be visited.
std::deque<Boundary> worklist_;
absl::flat_hash_set<HloInstruction*> visited_;
};
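// A typical use of BoundaryVisitor, mirroring the driver loop in
// ConditionalCodeMotion::Run below (a sketch, not the exact driver code):
//   BoundaryVisitor visitor(conditional);
//   while (visitor.HasNextBoundary()) {
//     Boundary b = visitor.PopNextBoundary();
//     // ... analyze b and decide whether to move it ...
//     visitor.AddToWorkList(next_boundary);  // enqueue follow-up boundaries
//   }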
template <class OpCollection>
int64 CountNonLeafOps(const OpCollection& ops) {
absl::flat_hash_set<HloInstruction*> op_set;
for (auto op : ops) {
if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) {
op_set.insert(op);
}
}
return op_set.size();
}
// Returns an estimate of the potential reuses carried by a given pair of
// instructions. Different integers are used to classify different levels
// of reuse. This is used as a placeholder only, assuming all
// instructions can be fused to enable data reuse.
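// For example (with hypothetical operand/user counts): a conditional carries
// a reuse weight of 10 toward its user; a parameter, constant, or
// get-tuple-element carries 0; and a generic op with two non-leaf users
// feeding a user with one non-leaf operand carries 10 / 2 / 1 = 5.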
int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
// Reuses in some way work like forces that pull instructions
// towards each other. We use a number 0-10 to classify how strong the force
// is between a pair of operations. Given a group of instructions that can be
// moved together, if the forces inside a conditional are stronger, the group
// will be moved into or remain inside the conditional; otherwise, it will
// be moved outside to or remain outside of the conditional.
VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
<< op->ToString() << "=>" << user->ToString() << "\n";
switch (user->opcode()) {
case HloOpcode::kGetTupleElement:
return 0;
case HloOpcode::kConvert:
// Because a convert is treated as not movable when it follows a dot or
// convolution, if op is a dot or convolution here, they must be separated
// by a conditional boundary. We do not try to pull the convert inside the
// conditional to be together with the dot or convolution.
switch (op->opcode()) {
case HloOpcode::kConvolution:
case HloOpcode::kDot:
return 0;
default:
break;
}
break;
default:
break;
}
switch (op->opcode()) {
// These instructions do not carry weight of reuse themselves.
case HloOpcode::kParameter:
case HloOpcode::kConstant:
case HloOpcode::kGetTupleElement:
return 0;
case HloOpcode::kConditional:
return 10;
default: {
// Assume the reuse decreases with increasing user count.
int count1 = CountNonLeafOps(op->users());
int count2 = CountNonLeafOps(user->operands());
return 10 / count1 / count2;
}
}
}
// Compares whether the instructions to be visited in each branch are identical.
bool InstructionWithinBranchIdentical(
const std::vector<HloInstruction*>& instructions,
bool is_layout_sensitive) {
// To be identical, the shapes of corresponding operands must be equal.
auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
bool eq_operands = is_layout_sensitive
? ShapeUtil::Equal(a->shape(), b->shape())
: ShapeUtil::Compatible(a->shape(), b->shape());
return eq_operands;
};
auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
return *a == *b;
};
if (instructions.empty()) {
return false;
}
if (instructions[0]->IsCrossModuleAllReduce()) {
return std::all_of(
instructions.begin(), instructions.end(),
[&](HloInstruction* instruction) {
if (!instruction->IsCrossModuleAllReduce()) {
return false;
}
auto old_channel_id = instruction->channel_id();
instruction->set_channel_id(instructions[0]->channel_id());
bool eq_instructions = instructions[0]->Identical(
*instruction, eq_operand, eq_computations, is_layout_sensitive);
instruction->set_channel_id(old_channel_id);
return eq_instructions;
});
}
return std::all_of(instructions.begin(), instructions.end(),
[&](HloInstruction* instruction) {
return instructions[0]->Identical(
*instruction, eq_operand, eq_computations,
is_layout_sensitive);
});
}
// Copies the instruction in the boundary to outside of the conditional, or
// does the opposite (copies it into branch `dest_index`, for moving in).
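// For instance (a hypothetical sketch of the moving-in direction): when
// copying user = add(gte(conditional, 0), constant) into a branch whose root
// is tuple(x, y), the gte operand resolves to the root operand x, the
// constant is cloned into the branch, and a new add(x, clone) is created
// there.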
Status CopyInOrOutOfConditional(
Boundary& boundary, int64 dest_index, HloComputation* parent,
absl::flat_hash_map<HloInstruction*, Boundary>& hoisted_instructions) {
CHECK(dest_index == 0 || boundary.IsOutsideBranch());
HloInstruction* op = boundary.operands()[0];
absl::InlinedVector<HloInstruction*, 4> new_operands;
for (int i = 0; i < op->operands().size(); ++i) {
auto op_i = op->operands()[i];
VLOG(2) << "Looking for " << op_i->ToString() << "\n";
if (ContainsKey(hoisted_instructions, op_i)) {
auto new_op_i =
FindOrDie(hoisted_instructions, op_i).operands()[dest_index];
VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
new_operands.push_back(new_op_i);
} else {
switch (op_i->opcode()) {
case HloOpcode::kConstant: {
auto new_op_i = parent->AddInstruction(op_i->Clone());
VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
new_operands.push_back(new_op_i);
break;
}
case HloOpcode::kGetTupleElement: {
auto gte = Cast<HloGetTupleElementInstruction>(op_i);
int64 index = gte->tuple_index();
HloInstruction* root = parent->root_instruction();
CHECK(root->opcode() == HloOpcode::kTuple &&
index < root->operand_count());
auto new_op_i = root->mutable_operand(index);
VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
new_operands.push_back(new_op_i);
break;
}
default:
LOG(FATAL) << "Unexpected out-of-boundary instruction:"
<< op_i->ToString() << "\n";
}
}
}
HloInstruction* new_instruction = parent->AddInstruction(
op->CloneWithNewOperands(op->shape(), new_operands));
VLOG(2) << "new instruction:" << new_instruction->ToString() << "\n";
// Map each original instruction in the boundary to the new instruction
// created on the other side of the conditional.
for (HloInstruction* op : boundary.operands()) {
Boundary b2 = ContainsKey(hoisted_instructions, op)
? hoisted_instructions[op]
: Boundary(boundary.IsOutsideBranch()
? Boundary::Position::kInsideBranch
: Boundary::Position::kOutsideBranch);
b2.mutable_operands().push_back(new_instruction);
hoisted_instructions[op] = b2;
}
return Status::OK();
}
// Identify converts to be hoisted/rematerialized out of the branch
// computations.
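// For example (hypothetical HLO): if every branch root has the form
//   ROOT tuple(convert(x_i), y_i)
// and the converts at operand 0 have matching shapes across branches, then
// tuple index 0 is identified as a special convert.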
absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
int branch_count,
HloInstruction* conditional,
bool is_layout_sensitive) {
absl::flat_hash_set<int64> kspecial_convert;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
continue;
}
bool replica = true;
HloInstruction* kspecial_convert_candidate =
old_root->mutable_operand(operand_num);
// Check whether an identical candidate appears in other branches
for (int others = 1; others < branch_count; ++others) {
HloInstruction* others_root =
conditional->branch_computation(others)->root_instruction();
bool eq_shape =
is_layout_sensitive
? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape())
: ShapeUtil::Compatible(
others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape());
if ((others_root->operand(operand_num)->opcode() ==
HloOpcode::kConvert) &&
eq_shape) {
// Nothing to be done.
} else {
replica = false;
break;
}
}
if (replica) {
kspecial_convert.insert(operand_num);
}
}
return kspecial_convert;
}
// Restructuring the conditional instruction as follows:
// i.e., %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape-change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
HloInstruction* conditional) {
HloInstruction* old_root = computation->root_instruction();
std::vector<HloInstruction*> new_operands;
int cur_index = 0;
for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
++cur_index) {
new_operands.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
conditional, cur_index)));
}
HloInstruction* new_tuple =
computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
if (old_root == conditional) {
computation->set_root_instruction(new_tuple);
} else {
std::vector<HloInstruction*> new_tuple_users;
for (auto conditional_user : conditional->users()) {
auto is_new_gte = absl::c_find_if(
new_operands,
[&](HloInstruction* instr) { return instr == conditional_user; });
if (is_new_gte == new_operands.end()) {
new_tuple_users.push_back(conditional_user);
}
}
for (auto new_tuple_user : new_tuple_users) {
TF_RETURN_IF_ERROR(
conditional->ReplaceUseWith(new_tuple_user, new_tuple));
}
}
VLOG(2) << "computation after root restructure:\n" << computation->ToString();
return Status::OK();
}
StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
bool is_layout_sensitive) {
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
}
// Determine whether all branch roots are tuples.
for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
HloInstruction* branch_root =
conditional->branch_computation(branch_num)->root_instruction();
if (branch_root->opcode() != HloOpcode::kTuple) {
return false;
}
}
HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
// Identify the gte using `index`.
auto find_gte = [](const HloInstruction* conditional_result,
int64 index) -> HloInstruction* {
for (HloInstruction* instr : conditional_result->users()) {
if (instr->opcode() != HloOpcode::kGetTupleElement) {
return nullptr;
}
if (instr->tuple_index() == index) {
return instr;
}
}
return nullptr;
};
// Captures tuple indices referring to converts to be rematerialized/hoisted.
absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
old_root, branch_count, conditional, is_layout_sensitive);
// Exit if we cannot find any converts to be hoisted.
if (kspecial_convert.empty()) {
return false;
}
TF_RETURN_IF_ERROR(
RestructureConditionalInstruction(conditional->parent(), conditional));
for (int branch = 0; branch < branch_count; branch++) {
old_root = conditional->branch_computation(branch)->root_instruction();
absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
std::vector<HloInstruction*> new_operands(old_root->operand_count());
absl::flat_hash_set<HloInstruction*> to_hoist_set;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
operand_num;
}
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
HloInstruction* hoist = old_root->mutable_operand(operand_num);
if (!kspecial_convert.contains(operand_num)) {
new_operands[operand_num] = old_root->mutable_operand(operand_num);
continue;
}
to_hoist_set.insert(hoist);
int64 new_tuple_count = old_root->operand_count();
// Replace the hoisted instr in the tuple with its operand(s). At least
// one of the operands of the hoist replaces it at its original tuple
// position; the rest are appended at the end.
bool inplace = true;
CHECK(!hoist->operands().empty());
for (HloInstruction* prod : hoist->operands()) {
if (inplace) {
map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
new_operands[map_inst_to_tuple_index[hoist]] = prod;
inplace = false;
} else {
map_inst_to_tuple_index[prod] = new_tuple_count++;
new_operands.push_back(prod);
}
}
}
// Create the new root instruction.
HloComputation* cur_branch = conditional->branch_computation(branch);
HloInstruction* new_branch_root =
cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
// The shape can vary since the operands to convert are now
// being returned through the branches' root.
cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));
// The conditional's parent computation only needs to be changed once,
// so only do it when processing branch 0.
if (branch != 0) {
continue;
}
HloComputation* conditional_parent = conditional->parent();
HloInstruction* newconditional =
conditional_parent->AddInstruction(HloInstruction::CreateConditional(
cur_branch->root_instruction()->shape(),
conditional->mutable_operand(0),
absl::MakeSpan(conditional->branch_computations()),
absl::MakeSpan(conditional->operands()).subspan(1)));
// Ensure that all the users of conditional refer to the new one.
TF_RETURN_IF_ERROR(
conditional->ReplaceAllUsesWithDifferentShape(newconditional));
TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
conditional = newconditional;
// Add the hoisted instructions in the parent.
for (HloInstruction* hoist : to_hoist_set) {
VLOG(2) << "Hoisting instruction:" << hoist->ToString();
int64 hoist_index = map_inst_to_tuple_index[hoist];
// Find out the gte that captured the hoisted instr result.
HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
CHECK(gte_hoist != nullptr);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* op : hoist->operands()) {
HloInstruction* gte = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(op->shape(), conditional,
map_inst_to_tuple_index[op]));
new_operands.push_back(gte);
}
HloInstruction* hoisted = conditional_parent->AddInstruction(
hoist->CloneWithNewOperands(hoist->shape(), new_operands));
VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
}
// No need to explicitly delete a hoisted instruction since, if it is
// dead, the subsequent DCE will remove it.
}
VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
return true;
}
// Hoist identical ops out of the conditional. Ops are considered identical
// if the shapes of their operands are identical and their properties are
// identical. Starts from the root instruction of each branch and collects
// the identical ops to hoist.
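// For illustration, a hypothetical before/after sketch of moving out:
//   before: branch_i { ... ROOT tuple(convert(x_i)) }
//           cond = conditional(...), use = gte(cond, 0)
//   after:  branch_i { ... ROOT tuple(x_i) }
//           cond = conditional(...), use = convert(gte(cond, 0))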
StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
HloInstruction* conditional, std::vector<Boundary>& to_move_out,
std::vector<Boundary>& new_boundaries) {
if (to_move_out.empty()) {
return false;
}
VLOG(1) << "Modifying code--number of boundaries to move out:"
<< to_move_out.size() << "\n";
HloComputation* conditional_parent = conditional->parent();
// Save the old users before adding new conditional user instructions.
std::vector<HloInstruction*> old_conditional_users = conditional->users();
// Maps instructions in the conditional body to instructions hoisted outside
// the conditional that compute the same value.
absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
// Insert GetTupleElement before the instructions whose operands might still
// be within the conditional.
VLOG(1) << "before opt:"
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
<< "\n";
int64 op_index = 0;
for (const Boundary& b : new_boundaries) {
HloInstruction* op = b.operands()[0];
CHECK(op != nullptr);
VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
HloInstruction* gtr = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(op->shape(), conditional,
op_index++));
Boundary b2(Boundary::Position::kOutsideBranch);
b2.mutable_operands().push_back(gtr);
hoisted_instructions[op] = b2;
}
// Copy boundary instructions out of the conditional.
// Visit the operands before their users and copy them, so that each
// copied user points to the correct operand.
for (int64 i = to_move_out.size() - 1; i >= 0; i--) {
TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(
to_move_out[i], 0, conditional_parent, hoisted_instructions));
}
VLOG(2) << "Done copy branch instructions out\n"
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
<< "\n";
// Change original users of the conditional to use the correct operands.
HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
for (auto user_instr : old_conditional_users) {
VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n";
CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(user_instr);
int64 index = tuple_opd->tuple_index();
CHECK(old_root->operands().size() > index);
HloInstruction* old_opd = old_root->operands()[index];
VLOG(2) << "old opd = " << old_opd << "\n";
CHECK(ContainsKey(hoisted_instructions, old_opd));
HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0];
CHECK(old_opd != nullptr);
CHECK(new_opd != nullptr);
VLOG(2) << "Try replace all uses of :" << old_opd->ToString() << "\n";
TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
}
VLOG(2) << "Done changing conditional users\n"
<< conditional_parent->ToString() << "\n";
// Create a tuple within each branch and set it as the root.
int64 branch_count = conditional->branch_count();
for (int i = 0; i < branch_count; i++) {
auto computation = conditional->branch_computation(i);
std::vector<HloInstruction*> elements;
for (const auto& b1 : new_boundaries) {
HloInstruction* op = b1.operands()[i];
CHECK(op != nullptr);
VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
elements.push_back(op);
}
HloInstruction* tuple =
computation->AddInstruction(HloInstruction::CreateTuple(elements));
computation->set_root_instruction(tuple, true);
VLOG(2) << "computation is :" << computation->ToString() << "\n";
// Remove hoisted instructions from the branches.
for (const auto& b2 : to_move_out) {
auto instr_to_remove = b2.operands()[i];
// Double check to make sure it is safe to delete the instruction.
// Complications may arise due to some operations in the alternative
// branches (branches 1..n) being placed into the boundaries multiple
// times.
if (!computation->IsMarkedAsDead(instr_to_remove) &&
instr_to_remove->user_count() == 0) {
VLOG(2) << "Removing boundary:" << b2.ToString() << "\n";
TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove));
}
}
}
// Change conditional instruction shape to the shape of the new root.
HloInstruction* new_root =
conditional->branch_computation(0)->root_instruction();
*conditional->mutable_shape() = new_root->shape();
//
VLOG(1) << "done moving instructions out of branches\n"
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
<< "\n";
return true;
}
// Hoist ops from outside of the conditional to inside the branches.
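// For illustration, a hypothetical before/after sketch of moving in:
//   before: branch_i { ... ROOT tuple(x_i) }
//           cond = conditional(...), u = gte(cond, 0), v = add(u, c)
//   after:  branch_i { ... ROOT tuple(add(x_i, c)) }
//           cond = conditional(...), v = gte(cond, 0)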
StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
HloInstruction* conditional, std::vector<Boundary>& to_move_in,
std::vector<Boundary>& new_boundaries) {
if (to_move_in.empty()) {
return false;
}
VLOG(1) << "Modifying code---number of boundaries to move in:"
<< to_move_in.size() << "\n";
VLOG(1) << "before opt:"
<< conditional->parent()->ToString(HloPrintOptions::Fingerprint())
<< "\n";
// Mapping instructions to be moved to their new representations.
absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
int64 to_move_in_size = to_move_in.size();
int64 branch_count = conditional->branch_count();
HloGetTupleElementInstruction* tuple_use =
DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]);
// If use_index is -1, the old conditional root entry used by to_move_in
// instructions still needs to be included as an entry of the modified
// conditional root, and the new result of the to_move_in instructions
// needs to be added as an extra entry of the modified root; otherwise, the
// old root entry will be replaced with the new result in the modified root.
// The entry replacement should be allowed only if tuple_use has <= 1 users.
int64 use_index = (tuple_use != nullptr && tuple_use->user_count() == 1)
? tuple_use->tuple_index()
: -1;
VLOG(2) << "Tuple use index = " << use_index << "\n";
// Number of old conditional entries still to be used outside.
// If the conditional shape is not a tuple, a tuple will be created, with
// subscript 0 saving the old operand being used.
int64 op_index =
conditional->shape().IsTuple()
? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1
: conditional->shape().tuple_shapes_size())
: 0;
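// For example (hypothetical shapes): if the conditional has shape
// (f32[2], f32[4]) and tuple_use = gte(conditional, 1) with a single user,
// then use_index = 1 and op_index starts at 1, since entry 1 of the old
// root will be replaced rather than kept.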
// Used to map the tuple_use instruction to its operand.
Boundary b_opd_use(Boundary::Position::kInsideBranch);
Boundary b_old_root(Boundary::Position::kInsideBranch);
// Create a new root instruction in each branch.
for (int i = 0; i < branch_count; i++) {
auto computation = conditional->branch_computation(i);
auto old_root = computation->root_instruction();
b_old_root.mutable_operands().push_back(old_root);
std::vector<HloInstruction*> operands;
if (old_root->opcode() == HloOpcode::kTuple) {
// Use operands of old_root directly, so old_root can be removed later.
for (int i = 0; i < old_root->operand_count(); ++i) {
if (i != use_index) {
operands.push_back(old_root->operands()[i]);
} else { // Map conditional use to the tuple operand.
b_opd_use.mutable_operands().push_back(old_root->operands()[i]);
}
}
} else if (old_root->shape().IsTuple()) {
// If old_root is not a kTuple but has tuple shape, elements within the
// tuple must be extracted first to be used by the new instructions.
const Shape& old_shape = old_root->shape();
for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) {
auto element =
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
old_shape.tuple_shapes(i), old_root, i));
if (i != use_index) {
operands.push_back(element);
} else {
b_opd_use.mutable_operands().push_back(element);
}
}
} else {
// If old_root is not a tuple and does not have tuple shape, use it
// to replace the conditional directly in the new computation.
b_opd_use.mutable_operands().push_back(conditional);
}
HloInstruction* new_root =
computation->AddInstruction(HloInstruction::CreateTuple(operands));
VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
computation->set_root_instruction(new_root,
/*accept_different_shape*/ true);
if (old_root->opcode() == HloOpcode::kTuple) {
TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
}
VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
}
// Update the get-tuple-element indices of the conditional's users.
if (use_index != -1) {
for (auto* user : conditional->users()) {
if (user->opcode() == HloOpcode::kGetTupleElement &&
user->tuple_index() > use_index) {
user->set_tuple_index(user->tuple_index() - 1);
}
}
}
hoisted_instructions[conditional] = b_old_root;
int64 cp_start = 0;
if (use_index >= 0) {
VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n";
hoisted_instructions[tuple_use] = b_opd_use;
}
cp_start = (tuple_use != nullptr) ? 1 : 0;
for (int64 to_move_index = cp_start; to_move_index < to_move_in_size;
to_move_index++) {
Boundary b_to_move = to_move_in[to_move_index];
HloInstruction* op = b_to_move.operands()[0];
CHECK(op != nullptr);
bool to_be_used_outside = true;
VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 &&
op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) {
to_be_used_outside = false;
VLOG(2) << "Instruction is not to be used outside the branch\n";
}
Boundary b(Boundary::Position::kInsideBranch);
for (int i = 0; i < branch_count; i++) {
auto computation = conditional->branch_computation(i);
VLOG(2) << "Copying to branch: " << i << "\n";
TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(b_to_move, i, computation,
hoisted_instructions));
VLOG(2) << "Done:" << computation->ToString() << "\n";
if (to_be_used_outside) {
auto new_op = hoisted_instructions[op].operands()[i];
auto new_root = computation->root_instruction();
new_root->AppendOperand(new_op);
*new_root->mutable_shape()->add_tuple_shapes() = new_op->shape();
VLOG(2) << "Extending conditional root " << i << " : "
<< new_root->ToString() << "\n";
}
VLOG(2) << "After extending branch root: " << computation->ToString()
<< "\n";
}
if (to_be_used_outside) {
// Modify uses of the instruction outside of the conditional.
HloInstruction* gtr = conditional->parent()->AddInstruction(
HloInstruction::CreateGetTupleElement(op->shape(), conditional,
op_index++));
TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr));
if (conditional->parent()->root_instruction() == op) {
conditional->parent()->set_root_instruction(gtr);
}
}
}
VLOG(2) << "Done copying instructions inside branch: "
<< conditional->ToString(HloPrintOptions::Fingerprint()) << "\n";
// Change conditional instruction shape to the shape of the new root.
HloInstruction* new_root =
conditional->branch_computation(0)->root_instruction();
*conditional->mutable_shape() = new_root->shape();
VLOG(2) << "Before removing instructions:"
<< conditional->parent()->ToString() << "\n";
// Remove hoisted instructions from the branches.
for (int64 i = to_move_in_size - 1; i >= 0; i--) {
Boundary boundary_to_move_in = to_move_in[i];
HloInstruction* op = boundary_to_move_in.operands()[0];
if (op->user_count() == 0) {
VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
VLOG(2) << "Done removing boundary.\n";
}
}
// Reset shapes of user gtes to the new shape.
if (use_index != -1) {
for (auto* user : conditional->users()) {
if (user->opcode() == HloOpcode::kGetTupleElement) {
VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n";
*user->mutable_shape() =
conditional->shape().tuple_shapes(user->tuple_index());
}
}
}
VLOG(1) << "Done moving instructions inside branches\n"
<< conditional->parent()->ToString(HloPrintOptions::Fingerprint())
<< "\n";
return true;
}
// Group single chains of operands or uses of boundaries into new boundaries
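// A boundary inside a branch holds one instruction per branch (e.g.,
// {add_0, add_1} for a two-branch conditional), which must be identical
// across branches to be movable; a boundary outside holds a single user of
// the conditional. This class walks operand or user chains from a seed
// boundary to group connected movable instructions together.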
class GroupConnectedBoundaries {
private:
std::vector<Boundary> connected_boundaries_, new_boundaries_;
HloInstruction* conditional_;
HloComputation* conditional_parent_;
bool is_layout_sensitive_;
// Instructions that have been visited but are not going to be moved.
absl::flat_hash_map<HloInstruction*, int>& visited_count_;
public:
explicit GroupConnectedBoundaries(
HloInstruction* conditional, bool is_layout_sensitive,
absl::flat_hash_map<HloInstruction*, int>& visited_count)
: conditional_(conditional),
conditional_parent_(conditional->parent()),
is_layout_sensitive_(is_layout_sensitive),
visited_count_(visited_count) {}
void clear_recently_visited() {
for (const auto& boundary : new_boundaries_) {
visited_count_.erase(boundary.operands()[0]);
}
}
// Returns true if `instruction` is worth hoisting.
bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) {
// This is needed for the "moving-in" transformation, to prevent the root
// of the parent computation (which contains the conditional) from being
// moved inside the conditional.
if (instruction->opcode() == HloOpcode::kTuple &&
instruction == conditional_parent_->root_instruction()) {
return false;
}
// TODO(b/169182921): The following cost model is rather incomplete. It
// will need to be extended to cover most of the element-wise ops.
switch (instruction->opcode()) {
case HloOpcode::kConvert:
// If Convert follows an AllReduce, it is worth moving the AllReduce
// out of the conditional for AR/CRS combining. If Convert follows other
// ops such as Dot or Convolution, it is better to keep the convert
// within the conditional so that it can be fused with the Dot or
// Convolution.
switch (instruction->operand(0)->opcode()) {
case HloOpcode::kAllReduce:
case HloOpcode::kReshape:
case HloOpcode::kGetTupleElement:
return true;
default:
VLOG(2) << "Instruction is convert and its operand is not known to "
"be worth hoisting\n";
return false;
}
case HloOpcode::kGetTupleElement:
switch (instruction->operand(0)->opcode()) {
// Do not move a GTE if its operand is a parameter.
case HloOpcode::kParameter:
return false;
default:
return true;
}
case HloOpcode::kAllReduce:
// It is not safe to move collective ops from outside to inside
// conditional branches, as it may cause synchronization problems when
// different layouts are assigned to different branches.
return is_inside_branch;
case HloOpcode::kAbs:
case HloOpcode::kReduce:
case HloOpcode::kAdd:
case HloOpcode::kPower:
case HloOpcode::kCopy:
case HloOpcode::kConstant:
case HloOpcode::kSubtract:
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
case HloOpcode::kSqrt:
case HloOpcode::kRsqrt:
case HloOpcode::kReshape:
case HloOpcode::kMinimum:
case HloOpcode::kMaximum:
return true;
default:
VLOG(2) << "Instruction is not known to be worth hoisting\n";
return false;
}
}
int64 ReusesBeforeBoundary(HloInstruction* user) {
int64 reuses = 0;
for (auto op : user->operands()) {
// The operand must be an instruction that is not going to be moved (if
// user is inside the conditional); otherwise it must be the conditional
// itself and its user must be outside of the conditional.
if (!ContainsKey(visited_count_, op) && op != conditional_) {
continue;
}
if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) {
if (op->opcode() == HloOpcode::kConditional) {
auto tuple = op->branch_computation(0)->root_instruction();
if (tuple->opcode() == HloOpcode::kTuple) {
auto index = tuple_gte->tuple_index();
CHECK(index < tuple->operand_count());
op = tuple->mutable_operand(index);
}
}
reuses += ReusesCarriedBy(op, user->users()[0]);
} else {
reuses += ReusesCarriedBy(op, user);
}
}
VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses
<< "\n";
return reuses;
}
int64 ReusesAfterBoundary(HloInstruction* user) {
CHECK(user != nullptr);
auto all_users = user->users();
// For now, assume that if an instruction has multiple consumers, it
// will not be reused, as the reuse may require duplication in
// fusion and so is expensive. If the situation changes in the future,
// some aspects of the overall algorithm will need to be redesigned to
// accommodate the change.
if (all_users.size() > 1) {
VLOG(2) << "Having multiple users from: " << user->ToString() << "\n";
return 0;
}
if (!all_users.empty()) {
auto op = all_users[0];
int64 reuses = 0;
// Only count reuses that run through the conditional root.
if (op == conditional_->branch_computation(0)->root_instruction()) {
int64 index = op->operand_index(user);
for (auto op2 : conditional_->users()) {
// If the use is not a get-tuple-element, do not consider it for now.
if (op2->opcode() == HloOpcode::kGetTupleElement) {
auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(op2);
if (index == tuple_opd->tuple_index()) {
all_users = op2->users();
if (!all_users.empty()) {
reuses += ReusesCarriedBy(user, all_users[0]);
break;
}
}
}
}
} else if (ContainsKey(visited_count_, op)) {
reuses += ReusesCarriedBy(user, op);
}
VLOG(2) << "reuses after instruction " << user->ToString() << ":"
<< reuses << "\n";
return reuses;
}
return 0;
}
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
int64 reuses_before = 0, reuses_after = 0;
if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch() &&
boundaries[0].operands()[0]->opcode() == HloOpcode::kGetTupleElement) {
// The only boundary of moving-in is the get_tuple_element op.
return -1;
}
// For cases like :
// branch0 {
// ROOT copy
// }
// branch1 {
// ...
// }
// cond = conditional(branch0, branch1)
// copy = copy(cond)
//
// We can fold the two copies thus reducing computation.
auto get_copy_folding_benefit = [&](HloInstruction* hlo) -> int64 {
if (hlo->opcode() != HloOpcode::kCopy) {
return 0;
}
const HloGetTupleElementInstruction* gte =
DynCast<HloGetTupleElementInstruction>(hlo->operand(0));
if (gte == nullptr) {
return 0;
}
const HloInstruction* conditional = gte->operand(0);
if (conditional != conditional_) {
return 0;
}
int64 benefit = 0;
for (auto* branch : conditional->called_computations()) {
HloInstruction* root = branch->root_instruction();
if (root->opcode() == HloOpcode::kTuple) {
const auto* tuple_operand = root->operand(gte->tuple_index());
if (tuple_operand->opcode() == HloOpcode::kCopy) {
if (Shape::Equal()(tuple_operand->operand(0)->shape(),
hlo->shape())) {
benefit += 10;
}
}
}
}
return benefit;
};
for (const Boundary& b : boundaries) {
auto op = b.operands()[0];
if (op == conditional_->branch_computation(0)->root_instruction()) {
continue;
}
VLOG(2) << "Benefit for " << op->ToString();
reuses_before += ReusesBeforeBoundary(op);
VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n";
reuses_after += ReusesAfterBoundary(op);
VLOG(2) << "Reuese after boundary so far : " << reuses_after << "\n";
}
int64 copy_folding_benefit = 0;
if (boundaries[0].IsOutsideBranch()) {
for (const Boundary& b : boundaries) {
auto op = b.operands()[0];
copy_folding_benefit += get_copy_folding_benefit(op);
}
}
VLOG(2) << "Copy folding benefit: " << copy_folding_benefit;
if (reuses_after == 0 && reuses_before == 0 && copy_folding_benefit == 0) {
return -1;
} else if (boundaries[0].IsInsideBranch()) {
return reuses_after - reuses_before;
} else {
return reuses_before - reuses_after - 1 + copy_folding_benefit;
}
}
Boundary GetNextBoundary(const Boundary& b, int64 op_index) {
Boundary b2(b.GetPosition());
for (int j = 0; j < b.operands().size(); ++j) {
HloInstruction* inst = b.operands()[j];
CHECK(inst != nullptr);
HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
: inst->users()[op_index];
CHECK(op != nullptr);
b2.mutable_operands().push_back(op);
}
return b2;
}
// Checks whether it is safe to move a boundary when it is visited through
// a dependent that has already been considered for moving.
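// For example (hypothetical): an instruction inside a branch with two users
// that are both being moved is deferred on the first visit (recorded in
// visited_count_ and new_boundaries_); on the second visit, all of its
// dependents have been counted, so it is recovered as safe to move.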
bool IsSafeToMoveBoundary(const Boundary& next_boundary) {
int64 next_boundary_count =
(next_boundary.IsInsideBranch())
? next_boundary.operands()[0]->user_count()
: CountNonLeafOps(next_boundary.operands()[0]->operands());
if (next_boundary_count <= 1) {
// If boundary has only a single or no dependent, safe to move.
return true;
} else {
if (!ContainsKey(visited_count_, next_boundary.operands()[0])) {
VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n"
<< " because it has multiple dependents: "
<< next_boundary_count << "\n";
visited_count_[next_boundary.operands()[0]] = 1;
new_boundaries_.push_back(next_boundary);
} else {
auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(),
next_boundary);
if (pos != new_boundaries_.end() ||
next_boundary.operands().size() == 1) {
int count = ++visited_count_[next_boundary.operands()[0]];
if (count == next_boundary_count) {
VLOG(2) << "Recovering next boundary " << next_boundary.ToString()
<< "\n"
<< " because all of its dependents have been visited: "
<< next_boundary_count << "\n";
visited_count_.erase(next_boundary.operands()[0]);
if (pos != new_boundaries_.end()) {
new_boundaries_.erase(pos);
}
return true;
}
} else {
VLOG(2) << "Skip incompatible multi-dependent boundary: "
<< next_boundary.ToString() << ":" << next_boundary_count
<< "\n";
}
}
}
return false;
}
// This function is reused both for moving boundaries out of and into a
// conditional. As a result, the readability is somewhat compromised.
// It might be nice to refactor this function to factor the outside-inside
// considerations into separate function pointer parameters to improve
// readability.
void AddBoundaries(const Boundary& boundary) {
BoundaryVisitor visitor;
visitor.AddToWorkList(boundary);
while (visitor.HasNextBoundary()) {
Boundary b = visitor.PopNextBoundary();
VLOG(2) << "visiting boundary " << b.ToString() << "\n";
if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
b.operands(), is_layout_sensitive_)) &&
IsSafeToMoveBoundary(b) &&
WorthHoisting(b.operands()[0], b.IsInsideBranch())) {
connected_boundaries_.push_back(b);
VLOG(2) << "boundary can be moved\n";
int64 operand_count = (b.IsInsideBranch())
? b.operands()[0]->operand_count()
: b.operands()[0]->users().size();
for (int i = 0; i < operand_count; i++) {
Boundary next_boundary = GetNextBoundary(b, i);
VLOG(2) << "Add operand/user " << i << " to visit later\n";
visitor.AddToWorkList(next_boundary);
}
} else {
VLOG(2) << "boundary cannot be moved\n";
visited_count_[b.operands()[0]] = 1;
new_boundaries_.push_back(b);
}
}
}
std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional,
const Boundary& b) {
// At the beginning of optimization, a conditional itself is added to a
// worklist. Here the conditional is expanded into two sets of boundaries:
// the first set contains the boundary that is inside branches and
// contains the root of all branches; the second set of boundaries
// contains all the users of the conditional.
HloInstruction* inst = b.operands()[0];
if (inst == conditional) {
int branch_count = inst->branch_count();
// Add conditional roots as a new boundary to visit.
Boundary boundary_in(Boundary::Position::kInsideBranch);
for (int i = 0; i < branch_count; i++) {
HloComputation* branch_computation = inst->branch_computation(i);
HloInstruction* root_inst = branch_computation->root_instruction();
CHECK(root_inst != nullptr);
boundary_in.mutable_operands().push_back(root_inst);
}
new_boundaries_.push_back(boundary_in);
// Add conditional users as new boundaries to visit.
for (auto u : inst->users()) {
Boundary boundary_in(Boundary::Position::kOutsideBranch);
boundary_in.mutable_operands().push_back(u);
new_boundaries_.push_back(boundary_in);
}
} else {
AddBoundaries(b);
}
return connected_boundaries_;
}
void AddNewBoundaries(std::vector<Boundary>& b) {
b.insert(b.end(), new_boundaries_.begin(), new_boundaries_.end());
}
};
ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
HloInstruction* conditional, const Boundary& cur_boundary,
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
absl::flat_hash_map<HloInstruction*, int>& visited_count) {
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
visited_count);
auto move_in_or_out =
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
if (!move_in_or_out.empty()) {
auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
VLOG(2) << "benefit of moving in or out "
<< cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
if (benefit >= 0) {
new_boundaries.clear();
connect.AddNewBoundaries(new_boundaries);
// The whole sequence in move_in_or_out is either all moving into a
// conditional, or all moving out of a conditional. So looking only
// at the first entry of the sequence is sufficient to determine the
// intended direction of the move.
to_move = move_in_or_out;
return Decision(to_move[0].IsInsideBranch()
? Decision::Direction::kMoveOutOfBranch
: Decision::Direction::kMoveIntoBranch,
benefit);
} else {
connect.clear_recently_visited();
}
} else {
connect.AddNewBoundaries(new_boundaries);
}
return Decision(Decision::Direction::kNoChange, 0);
}
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
// Used to support debugging of the optimization, by disabling it after it
// has been applied a predetermined number of times (to isolate the impact
// of transformations).
if (!ConsumeFuel("conditional_code_motion", [&] {
return "Skipping conditional opt after allowed limit reaching 0.\n";
})) {
return false;
}
bool changed = false;
bool cleanup_changed = false;
{
HloPassPipeline subpipeline("before_conditional_code_motion");
subpipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/is_layout_sensitive_);
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
cleanup_changed |= cleanup_changed_now;
}
// Gather all the conditional ops in the module ahead of time, to avoid
// potential complications from modifying the code while traversing it.
std::vector<HloInstruction*> conditional_ops;
// Track how many times each branch computation is shared.
absl::flat_hash_map<HloComputation*, int> conditional_computations;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
int branch_count = instr->branch_count();
for (int i = 0; i < branch_count; ++i) {
HloComputation* branch_i = instr->branch_computation(i);
if (ContainsKey(conditional_computations, branch_i)) {
conditional_computations[branch_i]++;
} else {
conditional_computations[branch_i] = 0;
}
}
if (instr->shape().IsTuple()) {
bool can_change_tuple_shape = true;
for (auto user : instr->users()) {
VLOG(2) << "user is : " << user->ToString() << "\n";
if (user->opcode() != HloOpcode::kGetTupleElement) {
can_change_tuple_shape = false;
}
}
if (can_change_tuple_shape) {
conditional_ops.push_back(instr);
}
} else {
conditional_ops.push_back(instr);
}
}
}
}
// Used to collect mappings between cloned instructions.
HloCloneContext clone_context(module);
for (HloInstruction* conditional : conditional_ops) {
int branch_count = conditional->branch_count();
// Check for shared conditional computations.
bool conditional_is_shared = false;
for (int i = 0; i < branch_count; ++i) {
HloComputation* branch_i = conditional->branch_computation(i);
if (conditional_computations[branch_i] > 0) {
conditional_is_shared = true;
break;
}
}
// Boundaries to move out or to move into the branches.
std::vector<std::vector<Boundary> > to_move_out, to_move_in;
std::vector<std::vector<Boundary> > new_boundaries_for_moveout;
std::vector<std::vector<Boundary> > new_boundaries_for_movein;
// Number of times each instruction has been visited for moving.
absl::flat_hash_map<HloInstruction*, int> visited_count;
int benefit_move_out = 0, benefit_move_in = 0;
Decision::Direction final_d = Decision::Direction::kNoChange;
// The conditional is moved into a worklist as the seed (starting point).
// The conditional will be expanded into multiple seeds (starting points),
// its roots and its users, when it is visited by GroupConnectedBoundaries.
// A NO_CHANGE decision will always be returned for the conditional itself,
// so that the other seeding boundaries can be visited in turn.
BoundaryVisitor visitor(conditional);
VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
// Try to visit all the boundaries, collect the analysis results, and save
// all the beneficial non-conflicting decisions. If two decisions conflict
// with each other, save the more beneficial one.
while (visitor.HasNextBoundary()) {
std::vector<Boundary> to_move, next_boundary;
Boundary boundary = visitor.PopNextBoundary();
VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary,
visited_count);
switch (d.GetDirection()) {
case Decision::Direction::kMoveOutOfBranch:
VLOG(2) << "Local Decision is move out of branch\n";
to_move_out.push_back(to_move);
new_boundaries_for_moveout.push_back(next_boundary);
benefit_move_out += d.GetBenefit();
if (benefit_move_out >= benefit_move_in) {
final_d = Decision::Direction::kMoveOutOfBranch;
VLOG(2) << "Current Decision is move out of branch ("
<< to_move_out.size() << ")\n";
} else {
VLOG(2) << "Current Decision remains move into branch\n";
}
break;
case Decision::Direction::kMoveIntoBranch:
VLOG(2) << "Decision is move into branch\n";
to_move_in.push_back(to_move);
new_boundaries_for_movein.push_back(next_boundary);
benefit_move_in += d.GetBenefit();
if (benefit_move_out >= benefit_move_in) {
VLOG(2) << "Current Decision remains move out of branch\n";
} else {
final_d = Decision::Direction::kMoveIntoBranch;
VLOG(2) << "Current Decision is move into branch ("
<< to_move_in.size() << ")\n";
}
break;
case Decision::Direction::kNoChange:
VLOG(2) << "Decision is no change\n";
for (const Boundary& b : next_boundary) {
visitor.AddToWorkList(b);
VLOG(2) << "Adding new boundary to worklist:" << b.ToString()
<< "\n";
}
break;
}
}
// If a modification is to be made, the shared branches need to be cloned.
if (final_d != Decision::Direction::kNoChange && conditional_is_shared) {
for (int i = 0; i < branch_count; ++i) {
HloComputation* branch_i = conditional->branch_computation(i);
if (conditional_computations[branch_i] > 0) {
// Cloning is absolutely needed if the computation is shared by
// different branches, but the cloning can be potentially avoided
// if the sharing is only among branches of the same conditional.
// If cloning these branches causes a problem due to space issues,
// a fix can pass a vector of unique branches to the actual
// transformations, as an alternative representation of the
// conditional branches to be modified. Right now we assume the
// overhead of cloning is minimal since later stages of the compiler
// inline all the computations anyway.
HloComputation* clone_i =
conditional->parent()->parent()->AddEmbeddedComputation(
branch_i->Clone("clone", &clone_context));
conditional->set_branch_computation(i, clone_i);
conditional_computations[branch_i]--;
// Need to translate the analysis result to generate the correct result.
auto update_boundary = [&](Boundary& boundary) {
auto cloned_instr =
clone_context.FindInstruction(boundary.operands()[i]);
CHECK(cloned_instr != nullptr);
VLOG(2) << "boundary before cloning:" << boundary.operands()[i]
<< "\n";
boundary.mutable_operands()[i] = cloned_instr;
VLOG(2) << "boundary after cloning:" << boundary.operands()[i]
<< "\n";
};
// Only boundaries to move out need to be updated.
if (final_d == Decision::Direction::kMoveOutOfBranch) {
for (int i = 0; i < to_move_out.size(); ++i) {
std::vector<Boundary>& m = to_move_out[i];
std::for_each(m.begin(), m.end(), update_boundary);
}
for (int i = 0; i < new_boundaries_for_moveout.size(); ++i) {
std::vector<Boundary>& m = new_boundaries_for_moveout[i];
std::for_each(m.begin(), m.end(), update_boundary);
}
}
}
}
VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
<< "\n";
}
// At most one of to_move_out or to_move_in can be non-empty, since there is
// only one optimization decision.
if (final_d == Decision::Direction::kMoveOutOfBranch) {
CHECK(to_move_out.size() == new_boundaries_for_moveout.size());
for (int i = 0; i < to_move_out.size(); ++i) {
TF_ASSIGN_OR_RETURN(bool result,
MoveInstructionOut(conditional, to_move_out[i],
new_boundaries_for_moveout[i]));
changed |= result;
}
VLOG(2) << "Done moving out of branches " << to_move_out.size()
<< " times. \n";
if (!ConsumeFuel("conditional_code_motion", [&] {
return "Skipping conditional opt after allowed limit reaching 0.\n";
})) {
break;
}
} else if (final_d == Decision::Direction::kMoveIntoBranch) {
CHECK(to_move_in.size() == new_boundaries_for_movein.size());
for (int i = 0; i < to_move_in.size(); ++i) {
TF_ASSIGN_OR_RETURN(bool result,
MoveInstructionIn(conditional, to_move_in[i],
new_boundaries_for_movein[i]));
changed |= result;
}
VLOG(2) << "Done moving into branches " << to_move_in.size()
<< " times. \n";
if (!ConsumeFuel("conditional_code_motion", [&] {
return "Skipping conditional opt after allowed limit reaching 0.\n";
})) {
break;
}
} else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
// Invoke special handling for convert rematerialization/hoisting
// We need to make sure no sharing is present in the branches because no
// cloning has been done by the earlier analysis.
// TODO(b/165848866): extend the solution to handle cloning for the special move.
TF_ASSIGN_OR_RETURN(
bool convert_result,
ConvertSpecialMove(conditional, is_layout_sensitive_));
if (convert_result) {
VLOG(2) << "Done special moving of convert\n";
if (!ConsumeFuel("conditional_code_motion", [&] {
return "Skipping conditional opt after allowed limit reaching "
"0.\n";
})) {
break;
}
}
changed |= convert_result;
}
}
if (changed) {
HloPassPipeline subpipeline(
"after_conditional_code_motion_after_convert_hoisting");
VLOG(2) << "starting after motion passes: DCE\n";
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
cleanup_changed |= cleanup_changed_now;
}
if (cleanup_changed) {
VLOG(2) << "subpipeline cleanup have modified code\n";
}
return changed;
}
} // namespace conditional_opt
} // namespace xla