Rollback of 4abf3012f1 which accidentally breaks compilation of tensorflow/compiler/xla/service/conditional_code_motion.cc

PiperOrigin-RevId: 320840323
Change-Id: Ifaf6dd56f2d7f885a85213097d77da901b216677
Mihai Maruseac 2020-07-12 07:39:22 -07:00 committed by TensorFlower Gardener
parent f356a508cf
commit f6984195e3
3 changed files with 396 additions and 532 deletions

tensorflow/compiler/xla/service/conditional_code_motion.cc

@@ -46,63 +46,161 @@ limitations under the License.
namespace xla {
namespace conditional_opt {
namespace {
struct ConditionalBoundary {
ConditionalBoundary(HloInstruction* op, int64 op_index, HloInstruction* usr)
: operand(op), operand_index(op_index), user(usr) {}
// `operand` is one of `user`'s operands.
// Instruction that remains in the conditional but one of its users
// is moved out of the conditional.
HloInstruction* operand;
// operand_index for `operand` in the `user`.
int64 operand_index;
// Instruction that moved out of conditional.
HloInstruction* user;
};
// Visit from the root instruction to its operands following BFS.
// Will visit an instruction after all its users have been visited. Parameters
// are not visited.
class BoundaryVisitor {
class BranchVisitor {
public:
// start with an existing conditional computation.
explicit BoundaryVisitor(HloInstruction* conditional) {
Boundary b(Boundary::Position::kInsideBranch);
b.Operands().push_back(conditional);
worklist_.push_back(b);
explicit BranchVisitor(const HloComputation* branch_computation) {
HloInstruction* root_inst = branch_computation->root_instruction();
worklist_.push_back(root_inst);
visited_.insert(root_inst);
for (auto parameter_inst : branch_computation->parameter_instructions()) {
parameter_instructions_.insert(parameter_inst);
}
}
// Start with an empty work list.
BoundaryVisitor() {}
// Get the next instruction to visit.
Boundary PopNextBoundary() {
CHECK(!worklist_.empty());
Boundary inst = worklist_.front();
worklist_.pop_front();
return inst;
}
void AddToWorkList(const Boundary& b) {
CHECK_GT(b.Operands().size(), 0);
worklist_.push_back(b);
HloInstruction* GetNextInstruction() {
if (!worklist_.empty()) {
HloInstruction* inst = worklist_.front();
worklist_.pop_front();
return inst;
}
return nullptr;
}
bool HasNextBoundary() const { return !worklist_.empty(); }
// Add operands of one instruction to worklist for further visit.
void AddInstructionOperands(HloInstruction* inst) {
int64 operand_count = inst->operand_count();
for (int i = 0; i < operand_count; i++) {
HloInstruction* operand = inst->mutable_operand(i);
if (ContainsKey(visited_, operand)) {
continue;
}
bool all_user_visited = std::all_of(
operand->users().begin(), operand->users().end(),
[&](HloInstruction* user) { return ContainsKey(visited_, user); });
if (!all_user_visited) {
continue;
}
// Do not visit parameter_instructions.
if (ContainsKey(parameter_instructions_, operand)) {
// Add the operand and this instruction to the boundaries.
boundaries_.emplace_back(operand, i, inst);
continue;
}
worklist_.push_back(operand);
visited_.insert(operand);
}
}
// Add instruction and its users to conditional boundaries.
void AddInstructionToBoundary(HloInstruction* inst) {
for (auto user : inst->users()) {
boundaries_.emplace_back(inst, user->operand_index(inst), user);
}
}
// Add instruction to the set and vector of instructions to hoist.
void AddInstructionToHoist(HloInstruction* inst) {
instructions_to_hoist_set_.insert(inst);
instructions_to_hoist_.emplace_back(inst);
}
// Whether the visitor has a next instruction to visit.
bool HasNextInstruction() const { return !worklist_.empty(); }
// Number of instructions to hoist.
int64 HoistInstructionSize() { return instructions_to_hoist_.size(); }
// Get boundaries of this branch.
const std::vector<ConditionalBoundary>& boundaries() const {
return boundaries_;
}
// Get instructions to hoist in this branch.
const std::vector<HloInstruction*>& instructions_to_hoist() const {
return instructions_to_hoist_;
}
// Get hoist instruction set in this branch.
const std::unordered_set<HloInstruction*>& instructions_to_hoist_set() const {
return instructions_to_hoist_set_;
}
private:
// worklist is the deque that contains instructions to be visited.
std::deque<Boundary> worklist_;
std::deque<HloInstruction*> worklist_;
// Instructions that have been visited.
std::unordered_set<HloInstruction*> visited_;
// parameter instructions of the branch.
std::unordered_set<HloInstruction*> parameter_instructions_;
// `boundaries_` contains the (operand, user) pairs where the operand stays
// within the conditional but the user can be hoisted out of the conditional.
std::vector<ConditionalBoundary> boundaries_;
// Instructions to hoist.
std::unordered_set<HloInstruction*> instructions_to_hoist_set_;
// Instructions to hoist; the order within this vector is BFS order, so an
// instruction always appears after its users.
std::vector<HloInstruction*> instructions_to_hoist_;
};
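For intuition, the visitor above walks a branch in reverse BFS order: start at the root and emit an instruction only after every one of its users has been emitted, never emitting parameters. Below is a minimal standalone sketch of that ordering; the `Node` struct and `VisitUsersFirst` helper are illustrative stand-ins, not XLA types.

#include <deque>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for an HLO instruction: a name, operands, and users.
struct Node {
  std::string name;
  bool is_parameter = false;
  std::vector<Node*> operands;
  std::vector<Node*> users;
};

// Walk from the root toward the operands, BFS-style, emitting a node only
// after all of its users have been emitted; parameters are never emitted.
std::vector<Node*> VisitUsersFirst(Node* root) {
  std::vector<Node*> order;
  std::deque<Node*> worklist{root};
  std::unordered_set<Node*> visited{root};
  while (!worklist.empty()) {
    Node* node = worklist.front();
    worklist.pop_front();
    order.push_back(node);
    for (Node* operand : node->operands) {
      if (visited.count(operand) || operand->is_parameter) continue;
      bool all_users_visited = true;
      for (Node* user : operand->users) {
        if (!visited.count(user)) {
          all_users_visited = false;
          break;
        }
      }
      // Skip for now; a later user will reconsider this operand.
      if (!all_users_visited) continue;
      worklist.push_back(operand);
      visited.insert(operand);
    }
  }
  return order;
}

int main() {
  // param -> a -> root and param -> b -> root: root is emitted first, then a, then b.
  Node param{"param", true}, a{"a"}, b{"b"}, root{"root"};
  a.operands = {&param};
  b.operands = {&param};
  root.operands = {&a, &b};
  param.users = {&a, &b};
  a.users = {&root};
  b.users = {&root};
  for (Node* n : VisitUsersFirst(&root)) std::cout << n->name << "\n";
  return 0;
}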
// Returns an estimate of the potential reuse carried by a given instruction.
// Different integers classify different levels of reuse.
// This is used as a placeholder only, assuming all instructions can be
// fused to enable data reuse.
int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
VLOG(1) << "ConditionalCodeMotion: Add reuses carried by instr: "
<< op->ToString() << "=>" << user->ToString() << "\n";
switch (user->opcode()) {
case HloOpcode::kGetTupleElement:
return 0;
default:
break;
// Returns true if `instruction` is worth hoisting out.
bool WorthHoisting(HloInstruction* instruction) {
for (const auto* operand : instruction->operands()) {
// Only move out instructions whose operands are not shared with other
// users, to avoid copying the operand.
if (operand->user_count() > 1) {
return false;
}
}
switch (op->opcode()) {
// These instructions are lightweight and easy to fuse.
case HloOpcode::kConstant:
return 0;
default:
// Assume fusion will not happen anyway if user count > 1.
if (op->user_count() > 1) {
return 0;
switch (instruction->opcode()) {
case HloOpcode::kConvert:
// If Convert is after AllReduce, it is worth moving the AllReduce out
// of the conditional for AR/CRS combine. If Convert is after other ops
// such as Dot or Convolution, it is better to keep the convert within the
// conditional so that it can be fused with the Dot or Convolution.
//
// TODO(b/154283721): figure out the scenario when convert can be fused
// with AllReduce out of conditional.
if (instruction->operand(0)->opcode() == HloOpcode::kAllReduce) {
return true;
}
return 10;
return false;
case HloOpcode::kAllReduce:
case HloOpcode::kAdd:
case HloOpcode::kConstant:
case HloOpcode::kSubtract:
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
case HloOpcode::kSqrt:
case HloOpcode::kGetTupleElement:
return true;
default:
return false;
}
}
@@ -122,7 +220,7 @@ bool InstructionWithinBranchIdentical(
return *a == *b;
};
if (instructions.empty()) {
if (instructions[0] == nullptr) {
return false;
}
@@ -150,27 +248,109 @@ bool InstructionWithinBranchIdentical(
});
}
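The identity check used here roughly requires the same opcode, the same result shape, and element-wise identical operand shapes across all branches. A toy version of that comparison, with strings standing in for opcodes and shapes (`ToyInstr` and `IdenticalAcrossBranches` are illustrative names, not the real HLO classes):

#include <algorithm>
#include <string>
#include <vector>

// Toy stand-in for an HLO instruction, reduced to the fields the check uses.
struct ToyInstr {
  std::string opcode;
  std::string shape;                        // e.g. "f32[2,512,364]"
  std::vector<std::string> operand_shapes;  // shapes of its operands
};

// One instruction per branch is "identical" across branches when the opcode,
// the result shape, and all operand shapes match (layouts ignored here).
bool IdenticalAcrossBranches(const std::vector<ToyInstr>& per_branch) {
  if (per_branch.empty()) return false;
  const ToyInstr& first = per_branch[0];
  return std::all_of(per_branch.begin() + 1, per_branch.end(),
                     [&](const ToyInstr& other) {
                       return other.opcode == first.opcode &&
                              other.shape == first.shape &&
                              other.operand_shapes == first.operand_shapes;
                     });
}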
// Copy identical instructions from within the conditional to outside of the conditional.
Status CopyOutOfConditional(
Boundary& boundary, HloComputation* conditional_parent,
absl::flat_hash_map<HloInstruction*, HloInstruction*>&
hoisted_instructions) {
// Insert GetTupleElement before the instructions whose operands might still
// be within the conditional.
HloInstruction* op = boundary.Operands()[0];
absl::InlinedVector<HloInstruction*, 4> new_operands;
for (int i = 0; i < op->operands().size(); ++i) {
auto op_i = op->operands()[i];
VLOG(2) << "Looking for operand:" << op_i->ToString() << "\n";
CHECK(ContainsKey(hoisted_instructions, op_i));
new_operands.push_back(FindOrDie(hoisted_instructions, op_i));
// Returns whether all the visitors/branches have a next instruction to visit.
bool HasNextInstruction(const std::vector<BranchVisitor>& visitors) {
bool has_next = true;
for (const auto& visitor : visitors) {
has_next &= visitor.HasNextInstruction();
}
HloInstruction* new_instruction = conditional_parent->AddInstruction(
op->CloneWithNewOperands(op->shape(), new_operands));
// Maps the instruction inside the conditional to the hoisted instruction
// outside of the conditional.
for (HloInstruction* op : boundary.Operands()) {
hoisted_instructions[op] = new_instruction;
return has_next;
}
// Create a tuple as the new root of the branch. The tuple will contain
// the operands that can't be moved out of the conditional but whose users
// will be moved out of the conditional.
HloInstruction* CreateNewRoot(
const std::vector<ConditionalBoundary>& boundaries,
const std::unordered_set<HloInstruction*>& instructions_to_hoist_set,
HloComputation* computation) {
std::vector<HloInstruction*> elements;
elements.reserve(boundaries.size());
for (auto boundary : boundaries) {
if (ContainsKey(instructions_to_hoist_set, boundary.user)) {
elements.push_back(boundary.operand);
}
}
return computation->AddInstruction(HloInstruction::CreateTuple(elements));
}
// Copy identical instructions from within the conditional to outside of the conditional.
void CopyIdenticalInstructionsOutOfConditional(
const std::vector<HloInstruction*>& instructions_to_hoist,
HloComputation* conditional_parent,
absl::flat_hash_map<HloInstruction*, HloInstruction*>*
hoisted_instructions) {
int64 instructions_size = instructions_to_hoist.size();
// Visit the operands before their users and copy them, so that the copied
// user will point to the correct operand.
for (int64 i = instructions_size - 1; i >= 0; i--) {
HloInstruction* old_instruction = instructions_to_hoist[i];
auto get_new_operand = [&](HloInstruction* old_operand) {
// If the operand can't be found in `instructions_to_hoist`, this
// operand will be in the `boundaries`, GetTupleElement instructions
// will be added later to replace this operand.
if (!ContainsKey(*hoisted_instructions, old_operand)) {
return old_operand;
}
return FindOrDie(*hoisted_instructions, old_operand);
};
absl::InlinedVector<HloInstruction*, 4> new_operands;
absl::c_transform(old_instruction->operands(),
std::back_inserter(new_operands), get_new_operand);
HloInstruction* new_instruction = conditional_parent->AddInstruction(
old_instruction->CloneWithNewOperands(old_instruction->shape(),
new_operands));
// Maps the instruction inside the conditional to the hoisted instruction
// outside of the conditional.
InsertOrDie(hoisted_instructions, old_instruction, new_instruction);
}
}
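A compressed sketch of the cloning order used above: because `instructions_to_hoist` is ordered users-before-operands, walking it backwards guarantees every operand is cloned before any of its users, and the map lets each cloned user point at its already-cloned operands. `Instr`, `HoistClones`, and the arena below are illustrative stand-ins, not XLA APIs.

#include <deque>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in for an HLO instruction: a name plus operand pointers.
struct Instr {
  std::string name;
  std::vector<Instr*> operands;
};

// `to_hoist` is ordered users-before-operands (the BFS order produced above),
// so walking it backwards clones every operand before any of its users; the
// map then lets each cloned user point at the already-cloned operand. Operands
// that were never hoisted are kept as-is here; the real pass later rewires
// them to get-tuple-elements of the conditional.
std::unordered_map<Instr*, Instr*> HoistClones(const std::vector<Instr*>& to_hoist,
                                               std::deque<Instr>& arena) {
  std::unordered_map<Instr*, Instr*> hoisted;
  for (auto it = to_hoist.rbegin(); it != to_hoist.rend(); ++it) {
    Instr* old_instr = *it;
    std::vector<Instr*> new_operands;
    for (Instr* op : old_instr->operands) {
      auto found = hoisted.find(op);
      new_operands.push_back(found == hoisted.end() ? op : found->second);
    }
    arena.push_back(Instr{old_instr->name + ".hoisted", std::move(new_operands)});
    hoisted[old_instr] = &arena.back();
  }
  return hoisted;
}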
// If there are instructions to hoist, the root of the conditional must be
// moved out. Change the users of the conditional to the hoisted instruction
// of the new root.
Status ChangeConditionalUsers(
HloInstruction* conditional, HloInstruction* old_root,
const absl::flat_hash_map<HloInstruction*, HloInstruction*>&
hoisted_instructions) {
HloInstruction* new_root = FindOrDie(hoisted_instructions, old_root);
TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWith(new_root));
return Status::OK();
}
// Insert GetTupleElement before the instructions whose operands might still
// be within the conditional.
Status CreateGetTupleElementAfterConditional(
const std::vector<ConditionalBoundary>& boundaries,
const std::unordered_set<HloInstruction*>& instructions_to_hoist_set,
const absl::flat_hash_map<HloInstruction*, HloInstruction*>&
hoisted_instructions,
HloInstruction* conditional, HloComputation* computation) {
int boundary_instruction_size = boundaries.size();
// Inserts GetTupleElement before the boundary instructions.
for (int i = 0; i < boundary_instruction_size; i++) {
HloInstruction* gte =
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
boundaries[i].operand->shape(), conditional, i));
HloInstruction* new_instruction =
FindOrDie(hoisted_instructions, boundaries[i].user);
TF_RETURN_IF_ERROR(
new_instruction->ReplaceOperandWith(boundaries[i].operand_index, gte));
}
return Status::OK();
}
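The key invariant in this step is an index correspondence: boundary i's operand becomes element i of each branch's new tuple root, and the already-hoisted user reads it back through get-tuple-element index i on the conditional. A toy model of that bookkeeping (all names below are illustrative stand-ins, not XLA APIs):

#include <cstddef>
#include <string>
#include <vector>

// Toy boundary: `operand_inside` stays in every branch, `user_outside` has
// already been hoisted, and `operand_index` is the slot of the user to rewire.
struct ToyBoundary {
  std::string operand_inside;
  std::string user_outside;
  int operand_index;
};

// For boundary i: the operand becomes element i of each branch's new tuple
// root, and the hoisted user reads it back as get-tuple-element(conditional, i).
void DescribeRewiring(const std::vector<ToyBoundary>& boundaries,
                      std::vector<std::string>* new_root_elements,
                      std::vector<std::string>* rewired_operands) {
  for (std::size_t i = 0; i < boundaries.size(); ++i) {
    new_root_elements->push_back(boundaries[i].operand_inside);
    rewired_operands->push_back(boundaries[i].user_outside + ".operand(" +
                                std::to_string(boundaries[i].operand_index) +
                                ") = get-tuple-element(conditional), index=" +
                                std::to_string(i));
  }
}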
// Remove instructions to be hoisted out of the branch computation.
Status RemoveInstructionFromComputation(
const std::vector<HloInstruction*>& instructions_to_hoist,
HloComputation* branch) {
// Will visit the instructions after their users.
for (auto* instruction : instructions_to_hoist) {
TF_RETURN_IF_ERROR(branch->RemoveInstruction(instruction));
}
return Status::OK();
}
@@ -394,359 +574,128 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
// are the shape of the operands are identical and their properties are
// identical. Will start from the root instruction of each branch and get
// the identical ops to hoist.
StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
HloInstruction* conditional, std::vector<Boundary>& to_move_out,
std::vector<Boundary>& new_boundaries) {
if (to_move_out.empty()) {
StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
bool is_layout_sensitive) {
VLOG(1) << " visiting conditional:" << conditional->ToString();
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
}
VLOG(1) << "number of boundaries to move out:" << to_move_out.size() << "\n";
HloComputation* conditional_parent = conditional->parent();
// Save the old users before adding new conditional user instructions.
std::vector<HloInstruction*> old_conditional_users = conditional->users();
absl::flat_hash_map<HloInstruction*, HloInstruction*> hoisted_instructions;
// Maps instructions in the conditional body to instructions hoisted outside
// the conditional that compute the same value.
VLOG(2) << "before opt:"
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
<< "\n";
int64 op_index = 0;
for (Boundary b : new_boundaries) {
HloInstruction* op = b.Operands()[0];
CHECK(op != nullptr);
VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
HloInstruction* gtr = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(op->shape(), conditional,
op_index++));
hoisted_instructions[op] = gtr;
std::vector<BranchVisitor> visitors;
visitors.reserve(branch_count);
// Visit instructions from the root instruction to the operands using BFS.
for (int i = 0; i < branch_count; i++) {
visitors.emplace_back(BranchVisitor(conditional->branch_computation(i)));
}
// Copy boundary instructions out of the conditional.
// Visit the operands before their users and copy them, so that the copied
// user will point to the correct operand.
for (int64 i = to_move_out.size() - 1; i >= 0; i--) {
TF_RETURN_IF_ERROR(CopyOutOfConditional(to_move_out[i], conditional_parent,
hoisted_instructions));
// The instructions to be visited within each branch.
std::vector<HloInstruction*> front_instructions(branch_count);
while (HasNextInstruction(visitors)) {
for (int i = 0; i < branch_count; i++) {
front_instructions[i] = visitors[i].GetNextInstruction();
}
// If the instructions have the same shape and opcode and their operands
// have the same shape, then they can be moved out of the conditional.
if (WorthHoisting(front_instructions[0]) &&
InstructionWithinBranchIdentical(front_instructions,
is_layout_sensitive)) {
for (int i = 0; i < branch_count; i++) {
visitors[i].AddInstructionOperands(front_instructions[i]);
visitors[i].AddInstructionToHoist(front_instructions[i]);
}
} else {
for (int i = 0; i < branch_count; i++) {
// If the ops are not identical, these ops and their users will
// be in the `boundaries` of the conditional. These ops will stay
// within the conditional, but their users will be moved out
// of the conditional.
visitors[i].AddInstructionToBoundary(front_instructions[i]);
}
}
}
VLOG(2) << "Done copy branch instructions out\n"
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
<< "\n";
// Change original users of the conditional to use the correct operands.
if (visitors[0].HoistInstructionSize() < 1) {
return false;
}
HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
for (auto user_instr : old_conditional_users) {
CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
auto tuple_opd = down_cast<HloGetTupleElementInstruction*>(user_instr);
int64 index = tuple_opd->tuple_index();
HloInstruction* old_opd = old_root->operands()[index];
HloInstruction* new_opd = hoisted_instructions[old_opd];
CHECK(old_opd != nullptr);
CHECK(new_opd != nullptr);
TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
}
HloComputation* conditional_parent = conditional->parent();
// Maps instructions in the conditional body to instructions hoisted outside
// the conditional that compute the same value.
absl::flat_hash_map<HloInstruction*, HloInstruction*> hoisted_instructions;
// Copy identical instructions out of the conditional.
CopyIdenticalInstructionsOutOfConditional(visitors[0].instructions_to_hoist(),
conditional_parent,
&hoisted_instructions);
// If there are instructions to hoist, the root of the conditional must be
// moved out. Change the users of the conditional to the hoisted instruction
// of the new root.
TF_RETURN_IF_ERROR(
ChangeConditionalUsers(conditional, old_root, hoisted_instructions));
// Create a tuple within each branch and set it as the root.
int64 branch_count = conditional->branch_count();
for (int i = 0; i < branch_count; i++) {
auto computation = conditional->branch_computation(i);
std::vector<HloInstruction*> elements;
for (auto b1 : new_boundaries) {
HloInstruction* op = b1.Operands()[i];
VLOG(1) << "branch count=" << i << "\n";
CHECK(op != nullptr);
VLOG(1) << "Adding to root " << i << " with " << op->ToString() << "\n";
elements.push_back(op);
}
HloInstruction* tuple =
computation->AddInstruction(HloInstruction::CreateTuple(elements));
computation->set_root_instruction(tuple, true);
VLOG(2) << "computation is :" << computation->ToString() << "\n";
// Remove hoisted instructions from the branches.
for (auto b2 : to_move_out) {
VLOG(2) << "Removing boundary:" << b2.ToString() << "\n";
TF_RETURN_IF_ERROR(computation->RemoveInstruction(b2.Operands()[i]));
}
HloInstruction* tuple = CreateNewRoot(
visitors[i].boundaries(), visitors[i].instructions_to_hoist_set(),
conditional->branch_computation(i));
conditional->branch_computation(i)->set_root_instruction(tuple, true);
}
// Changes conditional instruction shape to the shape of the new root.
*conditional->mutable_shape() =
conditional->branch_computation(0)->root_instruction()->shape();
// Insert GetTupleElement before the instructions whose operands might still
// be within the conditional.
TF_RETURN_IF_ERROR(CreateGetTupleElementAfterConditional(
visitors[0].boundaries(), visitors[0].instructions_to_hoist_set(),
hoisted_instructions, conditional, conditional_parent));
// Remove hoisted instructions from the branches.
for (int i = 0; i < branch_count; i++) {
TF_RETURN_IF_ERROR(
RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(),
conditional->branch_computation(i)));
}
// Change conditional instruction shape to the shape of the new root.
HloInstruction* new_root =
conditional->branch_computation(0)->root_instruction();
*conditional->mutable_shape() = new_root->shape();
//
VLOG(2) << "done moving instructions out of branches\n"
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
<< "\n";
return true;
}
// Group single chains of operands or uses of boundaries into new boundaries
class GroupConnectedBoundaries {
private:
std::unordered_set<HloInstruction*> visited_;
std::vector<Boundary> connected_boundaries_, new_boundaries_;
HloInstruction* conditional_;
bool is_layout_sensitive_;
public:
explicit GroupConnectedBoundaries(HloInstruction* conditional,
bool is_layout_sensitive)
: conditional_(conditional), is_layout_sensitive_(is_layout_sensitive) {}
// Returns true if `instruction` is worth hoisting out.
bool WorthHoisting(HloInstruction* instruction) {
switch (instruction->opcode()) {
case HloOpcode::kConvert:
// If Convert is after AllReduce, it is worth moving out AllReduce out
// of conditional for AR/CRS combine. If Convert is after other ops such
// as Dot or Convolutional, it is better to keep convert within
// conditional so that convert can be fused with Dot or Convolutional.
//
// TODO(b/154283721): figure out the scenario when convert can be fused
// with AllReduce out of conditional.
switch (instruction->operand(0)->opcode()) {
case HloOpcode::kAllReduce:
case HloOpcode::kReshape:
return true;
default:
VLOG(1) << "Instruction is convert and its operand is not know to "
"be worth hoisting\n";
return false;
}
case HloOpcode::kAllReduce:
case HloOpcode::kAdd:
case HloOpcode::kConstant:
case HloOpcode::kSubtract:
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
case HloOpcode::kSqrt:
case HloOpcode::kReshape:
case HloOpcode::kGetTupleElement:
return true;
default:
VLOG(1) << "Instruction is not known to be worth hoisting\n";
return false;
}
}
// Calculates the degree of reuses carried by a pair of conditional
// boundaries, if b1 is inside a conditional and b2 is outside.
int64 ReusesBeforeBoundary(HloInstruction* user) {
int64 reuses = 0;
for (auto op : user->operands()) {
// Only consider single-user cases as reusable.
if (ContainsKey(visited_, op) && op->user_count() == 1) {
reuses += ReusesCarriedBy(op, user);
}
}
VLOG(1) << "cost to be paied after moving out" << user->ToString() << ":"
<< reuses << "\n";
return reuses;
}
int64 ReusesAfterBoundary(HloInstruction* user) {
CHECK(user != nullptr);
auto all_users = user->users();
// For now, assume that if an instruction has multiple consumers, it will
// not be reused (the reuse currently requires duplication in fusion and so
// is expensive).
if (all_users.size() > 1) {
return 0;
}
if (!all_users.empty()) {
auto op = all_users[0];
int64 reuses = 0;
// Only count reuses that run through the conditional root.
if (op == conditional_->branch_computation(0)->root_instruction()) {
int64 index = op->operand_index(user);
for (auto op2 : conditional_->users()) {
CHECK(op2->opcode() == HloOpcode::kGetTupleElement);
auto tuple_opd = down_cast<HloGetTupleElementInstruction*>(op2);
if (index == tuple_opd->tuple_index()) {
all_users = op2->users();
if (!all_users.empty()) {
reuses += ReusesCarriedBy(user, all_users[0]);
break;
}
}
}
}
VLOG(1) << "reuses to be gained after moving " << user->ToString() << ":"
<< reuses << "\n";
return reuses;
}
return 0;
}
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
int64 reuses_before = 0, reuses_after = 0;
for (Boundary b : boundaries) {
auto op = b.Operands()[0];
if (op == conditional_->branch_computation(0)->root_instruction()) {
continue;
}
reuses_before += ReusesBeforeBoundary(op);
VLOG(1) << "Cost of moving so far: " << reuses_before << "\n";
reuses_after += ReusesAfterBoundary(op);
VLOG(1) << "Benefit from moving so far : " << reuses_after << "\n";
}
if (reuses_after == 0 && reuses_before == 0) {
return -1;
} else if (boundaries[0].IsInsideBranch()) {
return reuses_after - reuses_before;
} else {
return reuses_before - reuses_after;
}
}
Boundary GetNextBoundary(const Boundary& b, int64 op_index) {
Boundary b2(b.GetPosition());
CHECK(b.Operands().size() == conditional_->branch_count());
for (int j = 0; j < b.Operands().size(); ++j) {
HloInstruction* inst = b.Operands()[j];
CHECK(inst != nullptr);
HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
: inst->users()[op_index];
CHECK(op != nullptr);
b2.Operands().push_back(op);
}
return b2;
}
void AddBoundaries(const Boundary& boundary) {
BoundaryVisitor visitor;
visitor.AddToWorkList(boundary);
while (visitor.HasNextBoundary()) {
Boundary b = visitor.PopNextBoundary();
// if b is already visited, it must have multiple users and is already in
// new boundaries. Skip it.
if (ContainsKey(visited_, b.Operands()[0])) {
continue;
}
VLOG(1) << "visiting boundary " << b.ToString() << "\n";
if ((b.Operands().size() == 1 ||
InstructionWithinBranchIdentical(b.Operands(),
is_layout_sensitive_)) &&
WorthHoisting(b.Operands()[0])) {
connected_boundaries_.push_back(b);
VLOG(1) << "boundary can be moved\n";
int64 operand_count = (b.IsInsideBranch())
? b.Operands()[0]->operand_count()
: b.Operands()[0]->users().size();
for (int i = 0; i < operand_count; i++) {
Boundary b2 = GetNextBoundary(b, i);
int64 b2_count = (b2.IsInsideBranch())
? b2.Operands()[0]->user_count()
: b2.Operands()[0]->operand_count();
// Only consider adding an exclusive producer into the same group.
if (b2_count == 1) {
VLOG(2) << "Add operand " << i << " to visit later\n";
visitor.AddToWorkList(b2);
} else {
VLOG(2) << "Operand " << i << " has multiple uses\n";
if (!ContainsKey(visited_, b2.Operands()[0])) {
visited_.insert(b2.Operands()[0]);
new_boundaries_.push_back(b2);
}
}
}
} else {
VLOG(1) << "boundary cannot be moved\n";
visited_.insert(b.Operands()[0]);
new_boundaries_.push_back(b);
}
}
}
std::vector<Boundary> BoundariesToMoveOut(const Boundary& b) {
HloInstruction* inst = b.Operands()[0];
if (inst->opcode() == HloOpcode::kConditional) {
int branch_count = inst->branch_count();
// Visit instructions from the root instruction to the operands using BFS.
Boundary boundary_in(Boundary::Position::kInsideBranch);
for (int i = 0; i < branch_count; i++) {
HloComputation* branch_computation = inst->branch_computation(i);
HloInstruction* root_inst = branch_computation->root_instruction();
CHECK(root_inst != nullptr);
boundary_in.Operands().push_back(root_inst);
}
AddBoundaries(boundary_in);
}
return connected_boundaries_;
}
std::vector<Boundary> BoundariesToMoveIn(const Boundary& b) {
if (b.IsInsideBranch()) {
return std::vector<Boundary>();
}
AddBoundaries(b);
return connected_boundaries_;
}
std::vector<Boundary> GetNewBoundaries() { return new_boundaries_; }
};
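On the pre-rollback side of this diff, the move-out decision reduces to a small piece of arithmetic over the placeholder weights from ReusesCarriedBy (0 or 10). A sketch of just that arithmetic (the function name below is illustrative):

#include <cstdint>

// Reuse lost by moving a boundary out (fusion opportunities with its operands
// inside the branch) versus reuse gained (fusion opportunities with consumers
// of the conditional outside).
int64_t BenefitOfMovingOut(int64_t reuses_before_boundary,
                           int64_t reuses_after_boundary) {
  if (reuses_before_boundary == 0 && reuses_after_boundary == 0) {
    return -1;  // No reuse either way; treated as not worth moving.
  }
  return reuses_after_boundary - reuses_before_boundary;
}
// Example: an op whose sole in-branch consumer would have reused it (weight 10)
// and whose value is not reused outside gives 0 - 10 = -10, so the group is not
// moved (ConsiderCodeMotion requires a benefit >= 0).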
ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
HloInstruction* conditional, const Boundary& cur_boundary,
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries) {
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_);
auto move_out = connect.BoundariesToMoveOut(cur_boundary);
if (!move_out.empty()) {
std::vector<Boundary> next_boundaries = connect.GetNewBoundaries();
auto benefit = connect.BenefitForMovingBoundaries(move_out);
VLOG(1) << "benefit of moving " << cur_boundary.Operands()[0]->ToString()
<< ":" << benefit << "\n";
if (benefit >= 0) {
new_boundaries = next_boundaries;
to_move = move_out;
return Decision::kMoveOutOfBranch;
}
}
return ConditionalCodeMotion::Decision::kNoChange;
}
} // namespace
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
// Gather all the conditional ops in the module ahead of time, to avoid
// potential complications of modifying the code that affects traversal.
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
conditional_ops.push_back(instr);
bool changed = false;
if (pursue_full_conditional_code_motion_) {
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
conditional_ops.push_back(instr);
}
}
}
for (HloInstruction* conditional_op : conditional_ops) {
TF_ASSIGN_OR_RETURN(
bool result,
MergeIdenticalElements(conditional_op, is_layout_sensitive_));
changed |= result;
}
if (changed) {
HloPassPipeline subpipeline("after_conditional_code_motion");
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
changed |= cleanup_changed;
}
}
bool changed = false;
std::vector<Boundary> to_move_out, to_move_in, new_boundaries;
for (HloInstruction* conditional : conditional_ops) {
BoundaryVisitor visitor(conditional);
VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
// Boundaries to move out of and to move into the branches.
while (visitor.HasNextBoundary()) {
std::vector<Boundary> to_move, next_boundary;
Boundary boundary = visitor.PopNextBoundary();
VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
ConditionalCodeMotion::Decision d =
ConsiderCodeMotion(conditional, boundary, to_move, next_boundary);
switch (d) {
case Decision::kMoveOutOfBranch:
VLOG(2) << "Decision is move out of branch\n";
to_move_out.insert(to_move_out.end(), to_move.begin(), to_move.end());
break;
case Decision::kMoveIntoBranch:
VLOG(2) << "Decision is move into branch\n";
to_move_in.insert(to_move_in.end(), to_move.begin(), to_move.end());
break;
case Decision::kNoChange:
VLOG(2) << "Decision is no change\n";
new_boundaries.push_back(boundary);
break;
}
for (const Boundary& b : next_boundary) {
visitor.AddToWorkList(b);
}
}
TF_ASSIGN_OR_RETURN(
bool result,
MoveInstructionOut(conditional, to_move_out, new_boundaries));
VLOG(2) << "moving out result:" << result << "\n";
changed |= result;
}
// Handle convert rematerialization/hoisting.
if (!changed && pursue_full_conditional_code_motion_) {
{
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
@@ -762,6 +711,7 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
changed |= convert_result;
}
}
if (changed) {
HloPassPipeline subpipeline(
"after_conditional_code_motion_after_convert_hoisting");
@@ -771,8 +721,8 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
changed |= cleanup_changed;
}
return changed;
}
} // namespace conditional_opt
} // namespace xla

tensorflow/compiler/xla/service/conditional_code_motion.h

@@ -23,80 +23,35 @@ limitations under the License.
namespace xla {
namespace conditional_opt {
// At the conceptual level, a boundary can be thought of as representing a
// single virtual operation, except this virtual operation is conditionally
// instantiated into different concrete operations at each conditional branch.
// So a boundary is mapped to a single concrete operation if it is outside of
// conditional branches, and is mapped to a list of instructions if inside the
// branches. This data structure therefore allows a common data structure
// representation of the instructions to be moved, whether they are inside or
// outside of the branches. Subsequently, it allows a common implementation
// basis to be used for both moving instructions out of and for moving them
// inside branches.
class Boundary {
public:
enum class Position { kInsideBranch, kOutsideBranch };
explicit Boundary(Position p) : position_(p) {}
std::vector<HloInstruction*>& Operands() { return operands_; }
const std::vector<HloInstruction*>& Operands() const { return operands_; }
bool IsInsideBranch() const { return position_ == Position::kInsideBranch; }
bool IsOutsideBranch() const { return position_ == Position::kOutsideBranch; }
Position GetPosition() const { return position_; }
bool IsEmpty() const { return operands_.empty(); }
std::string ToString() const {
std::string res;
for (HloInstruction* op : operands_) {
res += op->ToString() + ";";
}
return res;
}
private:
// Boundary instructions in the conditional branches, one from each branch
// of the conditional.
std::vector<HloInstruction*> operands_;
Position position_;
};
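As a usage note for the pre-rollback Boundary API above, this is roughly how the .cc side of this diff seeds the walk: one inside-branch Boundary pairing up the root instruction of every branch (mirroring BoundariesToMoveOut). The helper name and the include path below are assumptions for illustration, not part of the pass.

#include "tensorflow/compiler/xla/service/conditional_code_motion.h"

namespace xla {
namespace conditional_opt {

// Illustrative helper: build the initial inside-branch boundary by pairing up
// the root instruction of every branch of the conditional.
inline Boundary MakeInitialInsideBoundary(HloInstruction* conditional) {
  Boundary boundary_in(Boundary::Position::kInsideBranch);
  for (int i = 0; i < conditional->branch_count(); ++i) {
    boundary_in.Operands().push_back(
        conditional->branch_computation(i)->root_instruction());
  }
  return boundary_in;
}

}  // namespace conditional_opt
}  // namespace xla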
// HLO pass that moves identical ops in/out of conditional.
// ConditionalCodeMotion specializes in hoisting/rematerializing
// unconditional converts in the default mode.
// When pursue_full_conditional_code_motion_ is set to true, the
// full HLO pass moves identical ops out of a conditional in addition to moving
// converts.
// - The definition of identical is that the shapes of the operands are
// identical and their properties are identical.
// - Currently, only some types of instructions are supported.
// TODO(b/154283721): relax non-sharable operand constraint and avoid copies in
// the new root.
// - Only the identical ops that won't share operands with other ops will
// be moved out of conditional.
class ConditionalCodeMotion : public HloModulePass {
public:
// If is_layout_sensitive is true, then the hoist process preserves layout
// during identical comparison. Otherwise, layout is ignored.
explicit ConditionalCodeMotion(bool is_layout_sensitive,
bool pursue_full_conditional_code_motion)
explicit ConditionalCodeMotion(
bool is_layout_sensitive = true,
bool pursue_full_conditional_code_motion = false)
: is_layout_sensitive_(is_layout_sensitive),
pursue_full_conditional_code_motion_(
pursue_full_conditional_code_motion) {}
absl::string_view name() const override { return "conditional-code-motion"; }
StatusOr<bool> Run(HloModule* module) override;
// Optimization decision for each boundary of the conditional instruction.
enum class Decision { kMoveOutOfBranch, kMoveIntoBranch, kNoChange };
// If the optimization decision is NO_CHANGE, new_boundary is set to nullptr;
// otherwise, it is set to the new boundary after proposed optimization.
virtual Decision ConsiderCodeMotion(HloInstruction* conditional,
const Boundary& cur_boundary,
std::vector<Boundary>& to_move,
std::vector<Boundary>& new_boundaries);
private:
const bool is_layout_sensitive_;
const bool pursue_full_conditional_code_motion_;
StatusOr<bool> MoveInstructionOut(HloInstruction* conditional,
std::vector<Boundary>& to_move_out,
std::vector<Boundary>& new_boundaries);
StatusOr<bool> MoveInstructionIn(HloInstruction* conditional,
std::vector<Boundary>& to_move_in,
std::vector<Boundary>& new_boundaries);
};
} // namespace conditional_opt
} // namespace xla
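For reference, the unit tests in the next file construct and run the pass as shown below; a minimal sketch of that pattern, where `hlo_string` is a placeholder for an HLO module text and ParseAndReturnVerifiedModule is the HloTestBase helper used by the tests.

// Sketch only: mirrors the pattern used in the test file below.
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(/*is_layout_sensitive=*/true,
                           /*pursue_full_conditional_code_motion=*/true);
bool changed = pass.Run(&*module).ValueOrDie();
// `changed` is true when identical ops were hoisted and cleanup passes ran.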

tensorflow/compiler/xla/service/conditional_code_motion_test.cc

@@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace conditional_opt {
namespace {
using ConditionalCodeMotionTest = HloTestBase;
namespace op = xla::testing::opcode_matchers;
@@ -117,47 +117,6 @@ ENTRY main {
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditional) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
}
on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
%sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOut) {
absl::string_view hlo_string =
R"(
@@ -193,20 +152,8 @@ ENTRY main {
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 2);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 2);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::Tuple(op::Add(op::Convert(op::Reshape(op::GetTupleElement(
op::GetTupleElement(op::Conditional())))),
op::Convert(op::Reshape(op::GetTupleElement(
op::GetTupleElement(op::Conditional()))))))));
EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert()))));
}
TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
@@ -226,7 +173,7 @@ on_true {
add.2 = f32[] add(add.1, constant.2)
add.3 = f32[] add(add.1, constant.3)
add.4 = f32[] add(add.3, constant.5)
multiply.1 = f32[] multiply(add.4, constant.4)
multiply.1 = f32[] multiply(add.2, constant.4)
ROOT tuple.6 = (f32[], f32[]) tuple(multiply.1, add.4)
}
@@ -269,11 +216,13 @@ ENTRY main {
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 9);
// Check that only one add and one multiply are moved out.
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root, AllOf(op::Tuple(op::Multiply(op::GetTupleElement(op::Conditional()),
op::Constant()),
op::GetTupleElement(op::Conditional()))));
root,
AllOf(op::Tuple(
op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()),
op::Add(op::GetTupleElement(op::Conditional()), op::Constant()))));
}
TEST_F(ConditionalCodeMotionTest, ConditionalRootElementChanged) {
@@ -320,16 +269,16 @@ ENTRY main {
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 1);
ASSERT_EQ(on_true->instruction_count(), 7);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 1);
ASSERT_EQ(on_false->instruction_count(), 7);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::Tuple(op::Add(
op::Add(op::GetTupleElement(op::Conditional()), op::Constant()),
op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))));
// add.3 in on_true will be moved out; add.1 and add.2 will be in the
// conditional root.
ASSERT_TRUE(ShapeUtil::Compatible(
conditional->shape(),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})})));
}
TEST_F(ConditionalCodeMotionTest, ConditionalIsRootInstruction) {
@@ -380,9 +329,24 @@ ENTRY main {
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
// If there is no instruction after the conditional, there is no benefit to
// moving instructions out.
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 9);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 9);
// Check that only one add and one multiply are moved out.
// add.3 and add.5 can't be moved out because they share operands with
// other instructions.
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::Tuple(
op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()),
op::Add(op::GetTupleElement(op::Conditional()), op::Constant()))));
}
TEST_F(ConditionalCodeMotionTest, LayoutMisMatchCannotMovedOut) {
@@ -505,8 +469,7 @@ ENTRY main {
false_computation=on_false
get-first-index = f32[3,3,128,128]
get-tuple-element(conditional), index=0
add.1 = f32[3,3,128,128] add(f32[3,3,128,128] get-first-index, f32[3,3,128,128] get-first-index)
ROOT result = (f32[3,3,128,128]) tuple(add.1)
ROOT result = (f32[3,3,128,128]) tuple(get-first-index)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
@@ -524,14 +487,10 @@ ENTRY main {
conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(
BF16, {3, 3, 128, 128})})));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::Tuple(op::Add(
op::Convert(op::AllReduce(op::GetTupleElement(op::Conditional()))),
op::Convert(
op::AllReduce(op::GetTupleElement(op::Conditional())))))));
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce(
op::GetTupleElement(op::Conditional()))))));
}
} // namespace conditional_opt
} // namespace
} // namespace xla