Rollback of 4abf3012f1
which accidentally breaks compilation of tensorflow/compiler/xla/service/conditional_code_motion.cc
PiperOrigin-RevId: 320840323 Change-Id: Ifaf6dd56f2d7f885a85213097d77da901b216677
This commit is contained in:
parent
f356a508cf
commit
f6984195e3
tensorflow/compiler/xla/service
@ -46,63 +46,161 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace conditional_opt {
|
||||
namespace {
|
||||
|
||||
struct ConditionalBoundary {
|
||||
ConditionalBoundary(HloInstruction* op, int64 op_index, HloInstruction* usr)
|
||||
: operand(op), operand_index(op_index), user(usr) {}
|
||||
// `operand` is one of `user`'s operand.
|
||||
|
||||
// Instruction that remains in the conditional but one of its user
|
||||
// is moved out of conditonal.
|
||||
HloInstruction* operand;
|
||||
// operand_index for `operand` in the `user`.
|
||||
int64 operand_index;
|
||||
// Instruction that moved out of conditional.
|
||||
HloInstruction* user;
|
||||
};
|
||||
|
||||
// Visit the root instructions to its operands follow BFS.
|
||||
// Will visit an instructions after all its users have been visited. Parameters
|
||||
// are not visited.
|
||||
class BoundaryVisitor {
|
||||
class BranchVisitor {
|
||||
public:
|
||||
// start with an existing conditional computation.
|
||||
explicit BoundaryVisitor(HloInstruction* conditional) {
|
||||
Boundary b(Boundary::Position::kInsideBranch);
|
||||
b.Operands().push_back(conditional);
|
||||
worklist_.push_back(b);
|
||||
explicit BranchVisitor(const HloComputation* branch_computation) {
|
||||
HloInstruction* root_inst = branch_computation->root_instruction();
|
||||
worklist_.push_back(root_inst);
|
||||
visited_.insert(root_inst);
|
||||
for (auto parameter_inst : branch_computation->parameter_instructions()) {
|
||||
parameter_instructions_.insert(parameter_inst);
|
||||
}
|
||||
}
|
||||
// Start with an empty work list.
|
||||
BoundaryVisitor() {}
|
||||
// Get next intruction to visit.
|
||||
Boundary PopNextBoundary() {
|
||||
CHECK(!worklist_.empty());
|
||||
Boundary inst = worklist_.front();
|
||||
worklist_.pop_front();
|
||||
return inst;
|
||||
}
|
||||
void AddToWorkList(const Boundary& b) {
|
||||
CHECK_GT(b.Operands().size(), 0);
|
||||
worklist_.push_back(b);
|
||||
HloInstruction* GetNextInstruction() {
|
||||
if (!worklist_.empty()) {
|
||||
HloInstruction* inst = worklist_.front();
|
||||
worklist_.pop_front();
|
||||
return inst;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool HasNextBoundary() const { return !worklist_.empty(); }
|
||||
// Add operands of one instruction to worklist for further visit.
|
||||
void AddInstructionOperands(HloInstruction* inst) {
|
||||
int64 operand_count = inst->operand_count();
|
||||
for (int i = 0; i < operand_count; i++) {
|
||||
HloInstruction* operand = inst->mutable_operand(i);
|
||||
if (ContainsKey(visited_, operand)) {
|
||||
continue;
|
||||
}
|
||||
bool all_user_visited = std::all_of(
|
||||
operand->users().begin(), operand->users().end(),
|
||||
[&](HloInstruction* user) { return ContainsKey(visited_, user); });
|
||||
|
||||
if (!all_user_visited) {
|
||||
continue;
|
||||
}
|
||||
// Do not visit parameter_instructions.
|
||||
if (ContainsKey(parameter_instructions_, operand)) {
|
||||
// Add the operand and this instruction to the boundaries.
|
||||
boundaries_.emplace_back(operand, i, inst);
|
||||
continue;
|
||||
}
|
||||
worklist_.push_back(operand);
|
||||
visited_.insert(operand);
|
||||
}
|
||||
}
|
||||
|
||||
// Add instruction and its users to conditional boundaries.
|
||||
void AddInstructionToBoundary(HloInstruction* inst) {
|
||||
for (auto user : inst->users()) {
|
||||
boundaries_.emplace_back(inst, user->operand_index(inst), user);
|
||||
}
|
||||
}
|
||||
|
||||
// Add instruction to the to be removed instructions set and vector.
|
||||
void AddInstructionToHoist(HloInstruction* inst) {
|
||||
instructions_to_hoist_set_.insert(inst);
|
||||
instructions_to_hoist_.emplace_back(inst);
|
||||
}
|
||||
|
||||
// If visitor has next instruction to visit.
|
||||
bool HasNextInstruction() const { return !worklist_.empty(); }
|
||||
|
||||
// If there is no hoist intruction.
|
||||
int64 HoistInstructionSize() { return instructions_to_hoist_.size(); }
|
||||
|
||||
// Get boundaries of this branch.
|
||||
const std::vector<ConditionalBoundary>& boundaries() const {
|
||||
return boundaries_;
|
||||
}
|
||||
|
||||
// Get instructions to hoist in this branch.
|
||||
const std::vector<HloInstruction*>& instructions_to_hoist() const {
|
||||
return instructions_to_hoist_;
|
||||
}
|
||||
|
||||
// Get hoist instruction set in this branch.
|
||||
const std::unordered_set<HloInstruction*>& instructions_to_hoist_set() const {
|
||||
return instructions_to_hoist_set_;
|
||||
}
|
||||
|
||||
private:
|
||||
// worklist is the deque that contains instructions to be visited.
|
||||
std::deque<Boundary> worklist_;
|
||||
std::deque<HloInstruction*> worklist_;
|
||||
|
||||
// instructions that has been visited.
|
||||
std::unordered_set<HloInstruction*> visited_;
|
||||
|
||||
// parameter instructions of the branch.
|
||||
std::unordered_set<HloInstruction*> parameter_instructions_;
|
||||
|
||||
// Boundaries contains the set of instructions that its operand is within
|
||||
// conditional but it can be hoist out of conditional.
|
||||
std::vector<ConditionalBoundary> boundaries_;
|
||||
|
||||
// Instructions to hoist.
|
||||
std::unordered_set<HloInstruction*> instructions_to_hoist_set_;
|
||||
|
||||
// Instructions to hoist, the order within this vector is BFS and
|
||||
// an instruction's order will always be after its users.
|
||||
std::vector<HloInstruction*> instructions_to_hoist_;
|
||||
};
|
||||
|
||||
// Returns estimation of potential reuses carried by a given instruction.
|
||||
// Use different integers to classify different levels of reuses
|
||||
// This is used as a placeholder only, assuming all instructions can be
|
||||
// fused to enable data reuses
|
||||
int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
|
||||
VLOG(1) << "ConditionalCodeMotion: Add reuses carried by instr: "
|
||||
<< op->ToString() << "=>" << user->ToString() << "\n";
|
||||
switch (user->opcode()) {
|
||||
case HloOpcode::kGetTupleElement:
|
||||
return 0;
|
||||
default:
|
||||
break;
|
||||
// Returns true if `instruction` is worth hoisting out.
|
||||
bool WorthHoisting(HloInstruction* instruction) {
|
||||
for (const auto* operand : instruction->operands()) {
|
||||
// Only move out instructions that won't share the same operand
|
||||
// to avoid copy of the operand.
|
||||
if (operand->user_count() > 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
switch (op->opcode()) {
|
||||
// These instructions are lightweight and easy to fuse.
|
||||
case HloOpcode::kConstant:
|
||||
return 0;
|
||||
default:
|
||||
// Assume fusion will not happen anyway if user count > 1)
|
||||
if (op->user_count() > 1) {
|
||||
return 0;
|
||||
switch (instruction->opcode()) {
|
||||
case HloOpcode::kConvert:
|
||||
// If Convert is after AllReduce, it is worth moving out AllReduce out
|
||||
// of conditional for AR/CRS combine. If Convert is after other ops such
|
||||
// as Dot or Convolutional, it is better to keep convert within
|
||||
// conditional so that convert can be fused with Dot or Convolutional.
|
||||
//
|
||||
// TODO(b/154283721): figure out the scenario when convert can be fused
|
||||
// with AllReduce out of conditional.
|
||||
if (instruction->operand(0)->opcode() == HloOpcode::kAllReduce) {
|
||||
return true;
|
||||
}
|
||||
return 10;
|
||||
return false;
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kTuple:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -122,7 +220,7 @@ bool InstructionWithinBranchIdentical(
|
||||
return *a == *b;
|
||||
};
|
||||
|
||||
if (instructions.empty()) {
|
||||
if (instructions[0] == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -150,27 +248,109 @@ bool InstructionWithinBranchIdentical(
|
||||
});
|
||||
}
|
||||
|
||||
// Copy identical instructions within conditional outside of conditional.
|
||||
Status CopyOutOfConditional(
|
||||
Boundary& boundary, HloComputation* conditional_parent,
|
||||
absl::flat_hash_map<HloInstruction*, HloInstruction*>&
|
||||
hoisted_instructions) {
|
||||
// Insert GetTupleElement before the instructions whose operands might still
|
||||
// be within the conditional.
|
||||
HloInstruction* op = boundary.Operands()[0];
|
||||
absl::InlinedVector<HloInstruction*, 4> new_operands;
|
||||
for (int i = 0; i < op->operands().size(); ++i) {
|
||||
auto op_i = op->operands()[i];
|
||||
VLOG(2) << "Looking for operand:" << op_i->ToString() << "\n";
|
||||
CHECK(ContainsKey(hoisted_instructions, op_i));
|
||||
new_operands.push_back(FindOrDie(hoisted_instructions, op_i));
|
||||
// Returns if all the visitors/branches has next instruction to visit.
|
||||
bool HasNextInstruction(const std::vector<BranchVisitor>& visitors) {
|
||||
bool has_next = true;
|
||||
for (const auto& visitor : visitors) {
|
||||
has_next &= visitor.HasNextInstruction();
|
||||
}
|
||||
HloInstruction* new_instruction = conditional_parent->AddInstruction(
|
||||
op->CloneWithNewOperands(op->shape(), new_operands));
|
||||
// Maps the instruction outside of conditional to the instruction
|
||||
// inside of the conditional.
|
||||
for (HloInstruction* op : boundary.Operands()) {
|
||||
hoisted_instructions[op] = new_instruction;
|
||||
return has_next;
|
||||
}
|
||||
|
||||
// Create tuple element as the new root of the branch. The tuple will contain
|
||||
// the operands that can't move out of conditional but its user will be moved
|
||||
// out of conditional.
|
||||
HloInstruction* CreateNewRoot(
|
||||
const std::vector<ConditionalBoundary>& boundaries,
|
||||
const std::unordered_set<HloInstruction*>& instructions_to_hoist_set,
|
||||
HloComputation* computation) {
|
||||
std::vector<HloInstruction*> elements;
|
||||
elements.reserve(boundaries.size());
|
||||
for (auto boundary : boundaries) {
|
||||
if (ContainsKey(instructions_to_hoist_set, boundary.user)) {
|
||||
elements.push_back(boundary.operand);
|
||||
}
|
||||
}
|
||||
return computation->AddInstruction(HloInstruction::CreateTuple(elements));
|
||||
}
|
||||
|
||||
// Copy identical instructions within conditional outside of conditional.
|
||||
void CopyIdenticalInstructionsOutOfConditional(
|
||||
const std::vector<HloInstruction*>& instructions_to_hoist,
|
||||
HloComputation* conditional_parent,
|
||||
absl::flat_hash_map<HloInstruction*, HloInstruction*>*
|
||||
hoisted_instructions) {
|
||||
int64 instructions_size = instructions_to_hoist.size();
|
||||
// Visit the operands before its users and copy it, so that the copied
|
||||
// user will point to the correct operand.
|
||||
for (int64 i = instructions_size - 1; i >= 0; i--) {
|
||||
HloInstruction* old_instruction = instructions_to_hoist[i];
|
||||
auto get_new_operand = [&](HloInstruction* old_operand) {
|
||||
// If the operand can't be found in `instructions_to_hoist`, this
|
||||
// operand will be in the `boundaries`, GetTupleElement instructions
|
||||
// will be added later to replace this operand.
|
||||
if (!ContainsKey(*hoisted_instructions, old_operand)) {
|
||||
return old_operand;
|
||||
}
|
||||
return FindOrDie(*hoisted_instructions, old_operand);
|
||||
};
|
||||
|
||||
absl::InlinedVector<HloInstruction*, 4> new_operands;
|
||||
absl::c_transform(old_instruction->operands(),
|
||||
std::back_inserter(new_operands), get_new_operand);
|
||||
|
||||
HloInstruction* new_instruction = conditional_parent->AddInstruction(
|
||||
old_instruction->CloneWithNewOperands(old_instruction->shape(),
|
||||
new_operands));
|
||||
// Maps the instruction outside of conditional to the instruction
|
||||
// inside of the conditional.
|
||||
InsertOrDie(hoisted_instructions, old_instruction, new_instruction);
|
||||
}
|
||||
}
|
||||
|
||||
// If there are instructions to hoist, the root of the conditional must be
|
||||
// moved out. Change the users of the conditional to the hoisted instruction
|
||||
// of the new root.
|
||||
Status ChangeConditionalUsers(
|
||||
HloInstruction* conditional, HloInstruction* old_root,
|
||||
const absl::flat_hash_map<HloInstruction*, HloInstruction*>&
|
||||
hoisted_instructions) {
|
||||
HloInstruction* new_root = FindOrDie(hoisted_instructions, old_root);
|
||||
TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWith(new_root));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Insert GetTupleElement before the instructions whose operands might still
|
||||
// be within the conditional.
|
||||
Status CreateGetTupleElementAfterConditional(
|
||||
const std::vector<ConditionalBoundary>& boundaries,
|
||||
const std::unordered_set<HloInstruction*>& instructions_to_hoist_set,
|
||||
const absl::flat_hash_map<HloInstruction*, HloInstruction*>&
|
||||
hoisted_instructions,
|
||||
HloInstruction* conditional, HloComputation* computation) {
|
||||
int boundary_instruction_size = boundaries.size();
|
||||
|
||||
// Inserts GetTupleElement before the boundary instructions.
|
||||
for (int i = 0; i < boundary_instruction_size; i++) {
|
||||
HloInstruction* gte =
|
||||
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
boundaries[i].operand->shape(), conditional, i));
|
||||
|
||||
HloInstruction* new_instruction =
|
||||
FindOrDie(hoisted_instructions, boundaries[i].user);
|
||||
TF_RETURN_IF_ERROR(
|
||||
new_instruction->ReplaceOperandWith(boundaries[i].operand_index, gte));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Remove instructions to be hoisted out of the branch computation.
|
||||
Status RemoveInstructionFromComputation(
|
||||
const std::vector<HloInstruction*>& instructions_to_hoist,
|
||||
HloComputation* branch) {
|
||||
// Will visit the instructions after its users.
|
||||
for (auto* instruction : instructions_to_hoist) {
|
||||
TF_RETURN_IF_ERROR(branch->RemoveInstruction(instruction));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -394,359 +574,128 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
|
||||
// are the shape of the operands are identical and their properties are
|
||||
// identical. Will start from the root instruction of each branch and get
|
||||
// the identical ops to hoist.
|
||||
StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
|
||||
HloInstruction* conditional, std::vector<Boundary>& to_move_out,
|
||||
std::vector<Boundary>& new_boundaries) {
|
||||
if (to_move_out.empty()) {
|
||||
StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
|
||||
bool is_layout_sensitive) {
|
||||
VLOG(1) << " visiting conditional:" << conditional->ToString();
|
||||
int branch_count = conditional->branch_count();
|
||||
if (branch_count <= 0) {
|
||||
return false;
|
||||
}
|
||||
VLOG(1) << "number of boundaries to move out:" << to_move_out.size() << "\n";
|
||||
HloComputation* conditional_parent = conditional->parent();
|
||||
// save the old users before add new conditional user instructions
|
||||
std::vector<HloInstruction*> old_conditional_users = conditional->users();
|
||||
absl::flat_hash_map<HloInstruction*, HloInstruction*> hoisted_instructions;
|
||||
// Maps instructions in the conditional body to instructions hoisted outside
|
||||
// the conditional that compute the same value.
|
||||
VLOG(2) << "before opt:"
|
||||
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
|
||||
<< "\n";
|
||||
int64 op_index = 0;
|
||||
for (Boundary b : new_boundaries) {
|
||||
HloInstruction* op = b.Operands()[0];
|
||||
CHECK(op != nullptr);
|
||||
VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
|
||||
HloInstruction* gtr = conditional_parent->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(op->shape(), conditional,
|
||||
op_index++));
|
||||
hoisted_instructions[op] = gtr;
|
||||
|
||||
std::vector<BranchVisitor> visitors;
|
||||
visitors.reserve(branch_count);
|
||||
// Visit instructions from the root instruction to the operands using BFS.
|
||||
for (int i = 0; i < branch_count; i++) {
|
||||
visitors.emplace_back(BranchVisitor(conditional->branch_computation(i)));
|
||||
}
|
||||
// Copy boundary instructions out of the conditional.
|
||||
// Visit the operands before its users and copy it, so that the copied
|
||||
// user will point to the correct operand.
|
||||
for (int64 i = to_move_out.size() - 1; i >= 0; i--) {
|
||||
TF_RETURN_IF_ERROR(CopyOutOfConditional(to_move_out[i], conditional_parent,
|
||||
hoisted_instructions));
|
||||
|
||||
// The instructions to be visited within each branch.
|
||||
std::vector<HloInstruction*> front_instructions(branch_count);
|
||||
|
||||
while (HasNextInstruction(visitors)) {
|
||||
for (int i = 0; i < branch_count; i++) {
|
||||
front_instructions[i] = visitors[i].GetNextInstruction();
|
||||
}
|
||||
// If two instructions has the same shape, opcode and its operands has the
|
||||
// same shape, then this instruction can be moved out of conditional.
|
||||
if (WorthHoisting(front_instructions[0]) &&
|
||||
InstructionWithinBranchIdentical(front_instructions,
|
||||
is_layout_sensitive)) {
|
||||
for (int i = 0; i < branch_count; i++) {
|
||||
visitors[i].AddInstructionOperands(front_instructions[i]);
|
||||
visitors[i].AddInstructionToHoist(front_instructions[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < branch_count; i++) {
|
||||
// If the ops are not identical, these ops and its users will
|
||||
// be in the boundaries` of the conditional. These ops will be stayed
|
||||
// within the conditional, but one its only user will be moved out
|
||||
// of conditional.
|
||||
visitors[i].AddInstructionToBoundary(front_instructions[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
VLOG(2) << "Done copy branch instructions out\n"
|
||||
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
|
||||
<< "\n";
|
||||
// Change original users of the conditional to use the correct operands.
|
||||
|
||||
if (visitors[0].HoistInstructionSize() < 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
HloInstruction* old_root =
|
||||
conditional->branch_computation(0)->root_instruction();
|
||||
for (auto user_instr : old_conditional_users) {
|
||||
CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
|
||||
auto tuple_opd = down_cast<HloGetTupleElementInstruction*>(user_instr);
|
||||
int64 index = tuple_opd->tuple_index();
|
||||
HloInstruction* old_opd = old_root->operands()[index];
|
||||
HloInstruction* new_opd = hoisted_instructions[old_opd];
|
||||
CHECK(old_opd != nullptr);
|
||||
CHECK(new_opd != nullptr);
|
||||
TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
|
||||
TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
|
||||
}
|
||||
HloComputation* conditional_parent = conditional->parent();
|
||||
// Maps instructions in the conditional body to instructions hoisted outside
|
||||
// the conditional that compute the same value.
|
||||
absl::flat_hash_map<HloInstruction*, HloInstruction*> hoisted_instructions;
|
||||
// Copy identical instructions out of the conditional.
|
||||
CopyIdenticalInstructionsOutOfConditional(visitors[0].instructions_to_hoist(),
|
||||
conditional_parent,
|
||||
&hoisted_instructions);
|
||||
// If there are instructions to hoist, the root of the conditional must be
|
||||
// moved out. Change the users of the conditional to the hoisted instruction
|
||||
// of the new root.
|
||||
TF_RETURN_IF_ERROR(
|
||||
ChangeConditionalUsers(conditional, old_root, hoisted_instructions));
|
||||
|
||||
// Create tuple element within each branch and set it as root.
|
||||
int64 branch_count = conditional->branch_count();
|
||||
for (int i = 0; i < branch_count; i++) {
|
||||
auto computation = conditional->branch_computation(i);
|
||||
std::vector<HloInstruction*> elements;
|
||||
for (auto b1 : new_boundaries) {
|
||||
HloInstruction* op = b1.Operands()[i];
|
||||
VLOG(1) << "branch count=" << i << "\n";
|
||||
CHECK(op != nullptr);
|
||||
VLOG(1) << "Adding to root " << i << " with " << op->ToString() << "\n";
|
||||
elements.push_back(op);
|
||||
}
|
||||
HloInstruction* tuple =
|
||||
computation->AddInstruction(HloInstruction::CreateTuple(elements));
|
||||
computation->set_root_instruction(tuple, true);
|
||||
VLOG(2) << "computation is :" << computation->ToString() << "\n";
|
||||
// Remove hoisted instructions from the branches.
|
||||
for (auto b2 : to_move_out) {
|
||||
VLOG(2) << "Removing boundary:" << b2.ToString() << "\n";
|
||||
TF_RETURN_IF_ERROR(computation->RemoveInstruction(b2.Operands()[i]));
|
||||
}
|
||||
HloInstruction* tuple = CreateNewRoot(
|
||||
visitors[i].boundaries(), visitors[i].instructions_to_hoist_set(),
|
||||
conditional->branch_computation(i));
|
||||
conditional->branch_computation(i)->set_root_instruction(tuple, true);
|
||||
}
|
||||
// Changes conditional instruction shape to the shape of the new root.
|
||||
*conditional->mutable_shape() =
|
||||
conditional->branch_computation(0)->root_instruction()->shape();
|
||||
|
||||
// Insert GetTupleElement before the instructions whose operands might still
|
||||
// be within the conditional.
|
||||
TF_RETURN_IF_ERROR(CreateGetTupleElementAfterConditional(
|
||||
visitors[0].boundaries(), visitors[0].instructions_to_hoist_set(),
|
||||
hoisted_instructions, conditional, conditional_parent));
|
||||
|
||||
// Remove hoist instructions from the branches.
|
||||
for (int i = 0; i < branch_count; i++) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(),
|
||||
conditional->branch_computation(i)));
|
||||
}
|
||||
// Change conditional instruction shape to the shape of the new root.
|
||||
HloInstruction* new_root =
|
||||
conditional->branch_computation(0)->root_instruction();
|
||||
*conditional->mutable_shape() = new_root->shape();
|
||||
//
|
||||
VLOG(2) << "done moving instructions out of branches\n"
|
||||
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
|
||||
<< "\n";
|
||||
return true;
|
||||
}
|
||||
|
||||
// Group single chains of operands or uses of boundaries into new boundaries
|
||||
class GroupConnectedBoundaries {
|
||||
private:
|
||||
std::unordered_set<HloInstruction*> visited_;
|
||||
std::vector<Boundary> connected_boundaries_, new_boundaries_;
|
||||
HloInstruction* conditional_;
|
||||
bool is_layout_sensitive_;
|
||||
|
||||
public:
|
||||
explicit GroupConnectedBoundaries(HloInstruction* conditional,
|
||||
bool is_layout_sensitive)
|
||||
: conditional_(conditional), is_layout_sensitive_(is_layout_sensitive) {}
|
||||
// Returns true if `instruction` is worth hoisting out.
|
||||
bool WorthHoisting(HloInstruction* instruction) {
|
||||
switch (instruction->opcode()) {
|
||||
case HloOpcode::kConvert:
|
||||
// If Convert is after AllReduce, it is worth moving out AllReduce out
|
||||
// of conditional for AR/CRS combine. If Convert is after other ops such
|
||||
// as Dot or Convolutional, it is better to keep convert within
|
||||
// conditional so that convert can be fused with Dot or Convolutional.
|
||||
//
|
||||
// TODO(b/154283721): figure out the scenario when convert can be fused
|
||||
// with AllReduce out of conditional.
|
||||
switch (instruction->operand(0)->opcode()) {
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kReshape:
|
||||
return true;
|
||||
default:
|
||||
VLOG(1) << "Instruction is convert and its operand is not know to "
|
||||
"be worth hoisting\n";
|
||||
return false;
|
||||
}
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kTuple:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
return true;
|
||||
default:
|
||||
VLOG(1) << "Instruction is not known to be worth hoisting\n";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Calculates the degree of reuses carried by a pair of conditional
|
||||
// boundaries, if b1 is inside a conditional and b2 is outside.
|
||||
int64 ReusesBeforeBoundary(HloInstruction* user) {
|
||||
int64 reuses = 0;
|
||||
for (auto op : user->operands()) {
|
||||
// Only consider single-user cases as reuseable.
|
||||
if (ContainsKey(visited_, op) && op->user_count() == 1) {
|
||||
reuses += ReusesCarriedBy(op, user);
|
||||
}
|
||||
}
|
||||
VLOG(1) << "cost to be paied after moving out" << user->ToString() << ":"
|
||||
<< reuses << "\n";
|
||||
return reuses;
|
||||
}
|
||||
|
||||
int64 ReusesAfterBoundary(HloInstruction* user) {
|
||||
CHECK(user != nullptr);
|
||||
auto all_users = user->users();
|
||||
// For now, assume that if an instruction has multiple-consumers, it will
|
||||
// not be reused (the reuse currently requires duplication in fusion and so
|
||||
// is expensive).
|
||||
if (all_users.size() > 1) {
|
||||
return 0;
|
||||
}
|
||||
if (!all_users.empty()) {
|
||||
auto op = all_users[0];
|
||||
int64 reuses = 0;
|
||||
// Only count reuses that run through the conditional root.
|
||||
if (op == conditional_->branch_computation(0)->root_instruction()) {
|
||||
int64 index = op->operand_index(user);
|
||||
for (auto op2 : conditional_->users()) {
|
||||
CHECK(op2->opcode() == HloOpcode::kGetTupleElement);
|
||||
auto tuple_opd = down_cast<HloGetTupleElementInstruction*>(op2);
|
||||
if (index == tuple_opd->tuple_index()) {
|
||||
all_users = op2->users();
|
||||
if (!all_users.empty()) {
|
||||
reuses += ReusesCarriedBy(user, all_users[0]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
VLOG(1) << "reuses to be gained after moving " << user->ToString() << ":"
|
||||
<< reuses << "\n";
|
||||
return reuses;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
|
||||
int64 reuses_before = 0, reuses_after = 0;
|
||||
for (Boundary b : boundaries) {
|
||||
auto op = b.Operands()[0];
|
||||
if (op == conditional_->branch_computation(0)->root_instruction()) {
|
||||
continue;
|
||||
}
|
||||
reuses_before += ReusesBeforeBoundary(op);
|
||||
VLOG(1) << "Cost of moving so far: " << reuses_before << "\n";
|
||||
reuses_after += ReusesAfterBoundary(op);
|
||||
VLOG(1) << "Benefit from moving so far : " << reuses_after << "\n";
|
||||
}
|
||||
if (reuses_after == 0 && reuses_before == 0) {
|
||||
return -1;
|
||||
} else if (boundaries[0].IsInsideBranch()) {
|
||||
return reuses_after - reuses_before;
|
||||
} else {
|
||||
return reuses_before - reuses_after;
|
||||
}
|
||||
}
|
||||
|
||||
Boundary GetNextBoundary(const Boundary& b, int64 op_index) {
|
||||
Boundary b2(b.GetPosition());
|
||||
CHECK(b.Operands().size() == conditional_->branch_count());
|
||||
for (int j = 0; j < b.Operands().size(); ++j) {
|
||||
HloInstruction* inst = b.Operands()[j];
|
||||
CHECK(inst != nullptr);
|
||||
HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
|
||||
: inst->users()[op_index];
|
||||
CHECK(op != nullptr);
|
||||
b2.Operands().push_back(op);
|
||||
}
|
||||
return b2;
|
||||
}
|
||||
void AddBoundaries(const Boundary& boundary) {
|
||||
BoundaryVisitor visitor;
|
||||
visitor.AddToWorkList(boundary);
|
||||
while (visitor.HasNextBoundary()) {
|
||||
Boundary b = visitor.PopNextBoundary();
|
||||
// if b is already visited, it must have multiple users and is already in
|
||||
// new boundaries. Skip it.
|
||||
if (ContainsKey(visited_, b.Operands()[0])) {
|
||||
continue;
|
||||
}
|
||||
VLOG(1) << "visiting boundary " << b.ToString() << "\n";
|
||||
if ((b.Operands().size() == 1 ||
|
||||
InstructionWithinBranchIdentical(b.Operands(),
|
||||
is_layout_sensitive_)) &&
|
||||
WorthHoisting(b.Operands()[0])) {
|
||||
connected_boundaries_.push_back(b);
|
||||
VLOG(1) << "boundary can be moved\n";
|
||||
int64 operand_count = (b.IsInsideBranch())
|
||||
? b.Operands()[0]->operand_count()
|
||||
: b.Operands()[0]->users().size();
|
||||
for (int i = 0; i < operand_count; i++) {
|
||||
Boundary b2 = GetNextBoundary(b, i);
|
||||
int64 b2_count = (b2.IsInsideBranch())
|
||||
? b2.Operands()[0]->user_count()
|
||||
: b2.Operands()[0]->operand_count();
|
||||
// only consider adding an exclusive producor into the same group.
|
||||
if (b2_count == 1) {
|
||||
VLOG(2) << "Add operand " << i << " to visit later\n";
|
||||
visitor.AddToWorkList(b2);
|
||||
} else {
|
||||
VLOG(2) << "Operand " << i << " has multiple uses\n";
|
||||
if (!ContainsKey(visited_, b2.Operands()[0])) {
|
||||
visited_.insert(b2.Operands()[0]);
|
||||
new_boundaries_.push_back(b2);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
VLOG(1) << "boundary cannot be moved\n";
|
||||
visited_.insert(b.Operands()[0]);
|
||||
new_boundaries_.push_back(b);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<Boundary> BoundariesToMoveOut(const Boundary& b) {
|
||||
HloInstruction* inst = b.Operands()[0];
|
||||
if (inst->opcode() == HloOpcode::kConditional) {
|
||||
int branch_count = inst->branch_count();
|
||||
// Visit instructions from the root instruction to the operands using BFS.
|
||||
Boundary boundary_in(Boundary::Position::kInsideBranch);
|
||||
for (int i = 0; i < branch_count; i++) {
|
||||
HloComputation* branch_computation = inst->branch_computation(i);
|
||||
HloInstruction* root_inst = branch_computation->root_instruction();
|
||||
CHECK(root_inst != nullptr);
|
||||
boundary_in.Operands().push_back(root_inst);
|
||||
}
|
||||
AddBoundaries(boundary_in);
|
||||
}
|
||||
return connected_boundaries_;
|
||||
}
|
||||
std::vector<Boundary> BoundariesToMoveIn(const Boundary& b) {
|
||||
if (b.IsInsideBranch()) {
|
||||
return std::vector<Boundary>();
|
||||
}
|
||||
AddBoundaries(b);
|
||||
return connected_boundaries_;
|
||||
}
|
||||
std::vector<Boundary> GetNewBoundaries() { return new_boundaries_; }
|
||||
};
|
||||
|
||||
ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
|
||||
HloInstruction* conditional, const Boundary& cur_boundary,
|
||||
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries) {
|
||||
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_);
|
||||
auto move_out = connect.BoundariesToMoveOut(cur_boundary);
|
||||
if (!move_out.empty()) {
|
||||
std::vector<Boundary> next_boundaries = connect.GetNewBoundaries();
|
||||
auto benefit = connect.BenefitForMovingBoundaries(move_out);
|
||||
VLOG(1) << "benefit of moving " << cur_boundary.Operands()[0]->ToString()
|
||||
<< ":" << benefit << "\n";
|
||||
if (benefit >= 0) {
|
||||
new_boundaries = next_boundaries;
|
||||
to_move = move_out;
|
||||
return Decision::kMoveOutOfBranch;
|
||||
}
|
||||
}
|
||||
return ConditionalCodeMotion::Decision::kNoChange;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
// Gather all the conditional ops in the module ahead of time, to avoid
|
||||
// potential complications of modifying the code that affecting traversal.
|
||||
std::vector<HloInstruction*> conditional_ops;
|
||||
for (auto* comp : module->MakeComputationPostOrder()) {
|
||||
for (auto* instr : comp->MakeInstructionPostOrder()) {
|
||||
if (instr->opcode() == HloOpcode::kConditional) {
|
||||
conditional_ops.push_back(instr);
|
||||
bool changed = false;
|
||||
|
||||
if (pursue_full_conditional_code_motion_) {
|
||||
std::vector<HloInstruction*> conditional_ops;
|
||||
for (auto* comp : module->MakeComputationPostOrder()) {
|
||||
for (auto* instr : comp->MakeInstructionPostOrder()) {
|
||||
if (instr->opcode() == HloOpcode::kConditional) {
|
||||
conditional_ops.push_back(instr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (HloInstruction* conditional_op : conditional_ops) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool result,
|
||||
MergeIdenticalElements(conditional_op, is_layout_sensitive_));
|
||||
changed |= result;
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
HloPassPipeline subpipeline("after_conditional_code_motion");
|
||||
subpipeline.AddPass<HloDCE>();
|
||||
subpipeline.AddPass<TupleSimplifier>();
|
||||
subpipeline.AddPass<HloDCE>();
|
||||
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
|
||||
changed |= cleanup_changed;
|
||||
}
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
std::vector<Boundary> to_move_out, to_move_in, new_boundaries;
|
||||
for (HloInstruction* conditional : conditional_ops) {
|
||||
BoundaryVisitor visitor(conditional);
|
||||
VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
|
||||
// Boundariess to move out of and to move into the branches.
|
||||
while (visitor.HasNextBoundary()) {
|
||||
std::vector<Boundary> to_move, next_boundary;
|
||||
Boundary boundary = visitor.PopNextBoundary();
|
||||
VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
|
||||
ConditionalCodeMotion::Decision d =
|
||||
ConsiderCodeMotion(conditional, boundary, to_move, next_boundary);
|
||||
switch (d) {
|
||||
case Decision::kMoveOutOfBranch:
|
||||
VLOG(2) << "Decision is move out of branch\n";
|
||||
to_move_out.insert(to_move_out.end(), to_move.begin(), to_move.end());
|
||||
break;
|
||||
case Decision::kMoveIntoBranch:
|
||||
VLOG(2) << "Decision is move into branch\n";
|
||||
to_move_in.insert(to_move_in.end(), to_move.begin(), to_move.end());
|
||||
break;
|
||||
case Decision::kNoChange:
|
||||
VLOG(2) << "Decision is no change\n";
|
||||
new_boundaries.push_back(boundary);
|
||||
break;
|
||||
}
|
||||
for (const Boundary& b : next_boundary) {
|
||||
visitor.AddToWorkList(b);
|
||||
}
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool result,
|
||||
MoveInstructionOut(conditional, to_move_out, new_boundaries));
|
||||
VLOG(2) << "moving out result:" << result << "\n";
|
||||
changed |= result;
|
||||
}
|
||||
// handling convert rematerialization/hoisting
|
||||
if (!changed && pursue_full_conditional_code_motion_) {
|
||||
{
|
||||
std::vector<HloInstruction*> conditional_ops;
|
||||
for (auto* comp : module->MakeComputationPostOrder()) {
|
||||
for (auto* instr : comp->MakeInstructionPostOrder()) {
|
||||
@ -762,6 +711,7 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
changed |= convert_result;
|
||||
}
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
HloPassPipeline subpipeline(
|
||||
"after_conditional_code_motion_after_convert_hoisting");
|
||||
@ -771,8 +721,8 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
|
||||
changed |= cleanup_changed;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
} // namespace conditional_opt
|
||||
|
||||
} // namespace xla
|
||||
|
@ -23,80 +23,35 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace conditional_opt {
|
||||
// At the conceptural level, a boundary can be thought of as representing a
|
||||
// single virtual operation, except this virtual operation is conditionally
|
||||
// instantiated into different concrete operations at each conditional branch.
|
||||
// So a boundary is mapped to a single concrete operation if it is outside of
|
||||
// conditional branches, and is mapped to a list of instructions if inside the
|
||||
// branches. This data structure therefore allows a common data structure
|
||||
// representation of the instructions to be moved, whether they are inside or
|
||||
// outside of the branches. Subsequently, it allows a common implementation
|
||||
// basis to be used for both moving instructions out of and for moving them
|
||||
// inside branches.
|
||||
class Boundary {
|
||||
public:
|
||||
enum class Position { kInsideBranch, kOutsideBranch };
|
||||
explicit Boundary(Position p) : position_(p) {}
|
||||
std::vector<HloInstruction*>& Operands() { return operands_; }
|
||||
const std::vector<HloInstruction*>& Operands() const { return operands_; }
|
||||
bool IsInsideBranch() const { return position_ == Position::kInsideBranch; }
|
||||
bool IsOutsideBranch() const { return position_ == Position::kOutsideBranch; }
|
||||
Position GetPosition() const { return position_; }
|
||||
bool IsEmpty() const { return operands_.empty(); }
|
||||
std::string ToString() const {
|
||||
std::string res;
|
||||
for (HloInstruction* op : operands_) {
|
||||
res += op->ToString() + ";";
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
private:
|
||||
// Boundary instructions in the conditional branches, one from each branch
|
||||
// of the conditional.
|
||||
std::vector<HloInstruction*> operands_;
|
||||
Position position_;
|
||||
};
|
||||
|
||||
// HLO pass that moves identical ops in/out of conditional.
|
||||
// ConditionalCodeMotion specializes in hoisting/rematerializing
|
||||
// unconditional converts in the default mode.
|
||||
// When pursue_full_conditional_code_motion_ is set to true, the
|
||||
// full HLO pass moves identical ops out of a conditional in addition to moving
|
||||
// converts.
|
||||
// - The definition of identical are the shape of the operands are identical
|
||||
// and their properties are identical.
|
||||
// - Currently, only some types of instructions is supported.
|
||||
// TODO(b/154283721): relax non-sharable operand constraint and avoid copies in
|
||||
// the new root.
|
||||
// - Only the identical ops that won't share operands with other ops will
|
||||
// be moved out of conditional.
|
||||
class ConditionalCodeMotion : public HloModulePass {
|
||||
public:
|
||||
// If is_layout_sensitive is true, then the hoist process preserves layout
|
||||
// during identical comparison. Otherwise, layout is ignored.
|
||||
explicit ConditionalCodeMotion(bool is_layout_sensitive,
|
||||
bool pursue_full_conditional_code_motion)
|
||||
explicit ConditionalCodeMotion(
|
||||
bool is_layout_sensitive = true,
|
||||
bool pursue_full_conditional_code_motion = false)
|
||||
: is_layout_sensitive_(is_layout_sensitive),
|
||||
pursue_full_conditional_code_motion_(
|
||||
pursue_full_conditional_code_motion) {}
|
||||
absl::string_view name() const override { return "conditional-code-motion"; }
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
// Optimization decision for each boundary of the conditional instruction.
|
||||
enum class Decision { kMoveOutOfBranch, kMoveIntoBranch, kNoChange };
|
||||
// If the optimization decision is NO_CHANGE, new_boundary is set to nullptr;
|
||||
// otherwise, it is set to the new boundary after proposed optimization.
|
||||
virtual Decision ConsiderCodeMotion(HloInstruction* conditional,
|
||||
const Boundary& cur_boundary,
|
||||
std::vector<Boundary>& to_move,
|
||||
std::vector<Boundary>& new_boundaries);
|
||||
|
||||
private:
|
||||
const bool is_layout_sensitive_;
|
||||
const bool pursue_full_conditional_code_motion_;
|
||||
|
||||
StatusOr<bool> MoveInstructionOut(HloInstruction* conditional,
|
||||
std::vector<Boundary>& to_move_out,
|
||||
std::vector<Boundary>& new_boundaries);
|
||||
StatusOr<bool> MoveInstructionIn(HloInstruction* conditional,
|
||||
std::vector<Boundary>& to_move_in,
|
||||
std::vector<Boundary>& new_boundaries);
|
||||
};
|
||||
} // namespace conditional_opt
|
||||
|
||||
} // namespace xla
|
||||
|
||||
|
@ -33,7 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace conditional_opt {
|
||||
namespace {
|
||||
|
||||
using ConditionalCodeMotionTest = HloTestBase;
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
@ -117,47 +117,6 @@ ENTRY main {
|
||||
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditional) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
HloModule RemoveDotOpOut
|
||||
|
||||
on_true {
|
||||
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
|
||||
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
|
||||
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
|
||||
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493)
|
||||
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493)
|
||||
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
|
||||
}
|
||||
|
||||
on_false {
|
||||
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
|
||||
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
|
||||
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
|
||||
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
|
||||
%sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717)
|
||||
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
|
||||
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
pred.1 = pred[] parameter(0)
|
||||
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
|
||||
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
|
||||
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
|
||||
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
|
||||
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index)
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
ConditionalCodeMotion pass(true, true);
|
||||
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
|
||||
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, MoveConvertOut) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
@ -193,20 +152,8 @@ ENTRY main {
|
||||
ConditionalCodeMotion pass(true, true);
|
||||
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
|
||||
|
||||
const HloInstruction* conditional =
|
||||
FindInstruction(module.get(), "conditional");
|
||||
const HloComputation* on_true = conditional->branch_computation(0);
|
||||
ASSERT_EQ(on_true->instruction_count(), 2);
|
||||
const HloComputation* on_false = conditional->branch_computation(1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 2);
|
||||
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(
|
||||
root,
|
||||
AllOf(op::Tuple(op::Add(op::Convert(op::Reshape(op::GetTupleElement(
|
||||
op::GetTupleElement(op::Conditional())))),
|
||||
op::Convert(op::Reshape(op::GetTupleElement(
|
||||
op::GetTupleElement(op::Conditional()))))))));
|
||||
EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert()))));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
|
||||
@ -226,7 +173,7 @@ on_true {
|
||||
add.2 = f32[] add(add.1, constant.2)
|
||||
add.3 = f32[] add(add.1, constant.3)
|
||||
add.4 = f32[] add(add.3, constant.5)
|
||||
multiply.1 = f32[] multiply(add.4, constant.4)
|
||||
multiply.1 = f32[] multiply(add.2, constant.4)
|
||||
ROOT tuple.6 = (f32[], f32[]) tuple(multiply.1, add.4)
|
||||
}
|
||||
|
||||
@ -269,11 +216,13 @@ ENTRY main {
|
||||
const HloComputation* on_false = conditional->branch_computation(1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 9);
|
||||
|
||||
// Check only one add and multiply is moved out.
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(
|
||||
root, AllOf(op::Tuple(op::Multiply(op::GetTupleElement(op::Conditional()),
|
||||
op::Constant()),
|
||||
op::GetTupleElement(op::Conditional()))));
|
||||
root,
|
||||
AllOf(op::Tuple(
|
||||
op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()),
|
||||
op::Add(op::GetTupleElement(op::Conditional()), op::Constant()))));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, ConditionalRootElementChanged) {
|
||||
@ -320,16 +269,16 @@ ENTRY main {
|
||||
const HloInstruction* conditional =
|
||||
FindInstruction(module.get(), "conditional");
|
||||
const HloComputation* on_true = conditional->branch_computation(0);
|
||||
ASSERT_EQ(on_true->instruction_count(), 1);
|
||||
ASSERT_EQ(on_true->instruction_count(), 7);
|
||||
const HloComputation* on_false = conditional->branch_computation(1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 7);
|
||||
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(
|
||||
root,
|
||||
AllOf(op::Tuple(op::Add(
|
||||
op::Add(op::GetTupleElement(op::Conditional()), op::Constant()),
|
||||
op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))));
|
||||
// add.3 in on_true will be moved out, add.1 and add.2 will be in condtional
|
||||
// root.
|
||||
ASSERT_TRUE(ShapeUtil::Compatible(
|
||||
conditional->shape(),
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})})));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, ConditionalIsRootInstruction) {
|
||||
@ -380,9 +329,24 @@ ENTRY main {
|
||||
)";
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
ConditionalCodeMotion pass(true, true);
|
||||
// If there is no instruction after the conditional, there is no benefit to
|
||||
// move
|
||||
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
|
||||
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
|
||||
|
||||
const HloInstruction* conditional =
|
||||
FindInstruction(module.get(), "conditional");
|
||||
const HloComputation* on_true = conditional->branch_computation(0);
|
||||
ASSERT_EQ(on_true->instruction_count(), 9);
|
||||
const HloComputation* on_false = conditional->branch_computation(1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 9);
|
||||
|
||||
// Check only one add and multiply is moved out.
|
||||
// add.3 and add.5 can't be moved out because they share operands with
|
||||
// other instructions.
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(
|
||||
root,
|
||||
AllOf(op::Tuple(
|
||||
op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()),
|
||||
op::Add(op::GetTupleElement(op::Conditional()), op::Constant()))));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, LayoutMisMatchCannotMovedOut) {
|
||||
@ -505,8 +469,7 @@ ENTRY main {
|
||||
false_computation=on_false
|
||||
get-first-index = f32[3,3,128,128]
|
||||
get-tuple-element(conditional), index=0
|
||||
add.1 = f32[3,3,128,128] add(f32[3,3,128,128] get-first-index, f32[3,3,128,128] get-first-index)
|
||||
ROOT result = (f32[3,3,128,128]) tuple(add.1)
|
||||
ROOT result = (f32[3,3,128,128]) tuple(get-first-index)
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
@ -524,14 +487,10 @@ ENTRY main {
|
||||
conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(
|
||||
BF16, {3, 3, 128, 128})})));
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(
|
||||
root,
|
||||
AllOf(op::Tuple(op::Add(
|
||||
op::Convert(op::AllReduce(op::GetTupleElement(op::Conditional()))),
|
||||
op::Convert(
|
||||
op::AllReduce(op::GetTupleElement(op::Conditional())))))));
|
||||
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce(
|
||||
op::GetTupleElement(op::Conditional()))))));
|
||||
}
|
||||
|
||||
} // namespace conditional_opt
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user