From f6984195e3ea91f05d329725086eb2af4043b07d Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Sun, 12 Jul 2020 07:39:22 -0700 Subject: [PATCH] Rollback of 4abf3012f14b76aa0fd1ce987f9afdad7639d734 which accidentally breaks compilation of `tensorflow/compiler/xla/service/conditional_code_motion.cc` PiperOrigin-RevId: 320840323 Change-Id: Ifaf6dd56f2d7f885a85213097d77da901b216677 --- .../xla/service/conditional_code_motion.cc | 744 ++++++++---------- .../xla/service/conditional_code_motion.h | 67 +- .../service/conditional_code_motion_test.cc | 117 +-- 3 files changed, 396 insertions(+), 532 deletions(-) diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index 3287726b1a7..6db4c3eb6d4 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -46,63 +46,161 @@ limitations under the License. namespace xla { -namespace conditional_opt { +namespace { + +struct ConditionalBoundary { + ConditionalBoundary(HloInstruction* op, int64 op_index, HloInstruction* usr) + : operand(op), operand_index(op_index), user(usr) {} + // `operand` is one of `user`'s operand. + + // Instruction that remains in the conditional but one of its user + // is moved out of conditonal. + HloInstruction* operand; + // operand_index for `operand` in the `user`. + int64 operand_index; + // Instruction that moved out of conditional. + HloInstruction* user; +}; // Visit the root instructions to its operands follow BFS. // Will visit an instructions after all its users have been visited. Parameters // are not visited. -class BoundaryVisitor { +class BranchVisitor { public: - // start with an existing conditional computation. - explicit BoundaryVisitor(HloInstruction* conditional) { - Boundary b(Boundary::Position::kInsideBranch); - b.Operands().push_back(conditional); - worklist_.push_back(b); + explicit BranchVisitor(const HloComputation* branch_computation) { + HloInstruction* root_inst = branch_computation->root_instruction(); + worklist_.push_back(root_inst); + visited_.insert(root_inst); + for (auto parameter_inst : branch_computation->parameter_instructions()) { + parameter_instructions_.insert(parameter_inst); + } } - // Start with an empty work list. - BoundaryVisitor() {} // Get next intruction to visit. - Boundary PopNextBoundary() { - CHECK(!worklist_.empty()); - Boundary inst = worklist_.front(); - worklist_.pop_front(); - return inst; - } - void AddToWorkList(const Boundary& b) { - CHECK_GT(b.Operands().size(), 0); - worklist_.push_back(b); + HloInstruction* GetNextInstruction() { + if (!worklist_.empty()) { + HloInstruction* inst = worklist_.front(); + worklist_.pop_front(); + return inst; + } + return nullptr; } - bool HasNextBoundary() const { return !worklist_.empty(); } + // Add operands of one instruction to worklist for further visit. + void AddInstructionOperands(HloInstruction* inst) { + int64 operand_count = inst->operand_count(); + for (int i = 0; i < operand_count; i++) { + HloInstruction* operand = inst->mutable_operand(i); + if (ContainsKey(visited_, operand)) { + continue; + } + bool all_user_visited = std::all_of( + operand->users().begin(), operand->users().end(), + [&](HloInstruction* user) { return ContainsKey(visited_, user); }); + + if (!all_user_visited) { + continue; + } + // Do not visit parameter_instructions. + if (ContainsKey(parameter_instructions_, operand)) { + // Add the operand and this instruction to the boundaries. 
+        boundaries_.emplace_back(operand, i, inst);
+        continue;
+      }
+      worklist_.push_back(operand);
+      visited_.insert(operand);
+    }
+  }
+
+  // Add instruction and its users to conditional boundaries.
+  void AddInstructionToBoundary(HloInstruction* inst) {
+    for (auto user : inst->users()) {
+      boundaries_.emplace_back(inst, user->operand_index(inst), user);
+    }
+  }
+
+  // Add instruction to the to be removed instructions set and vector.
+  void AddInstructionToHoist(HloInstruction* inst) {
+    instructions_to_hoist_set_.insert(inst);
+    instructions_to_hoist_.emplace_back(inst);
+  }
+
+  // If visitor has next instruction to visit.
+  bool HasNextInstruction() const { return !worklist_.empty(); }
+
+  // If there is no hoist intruction.
+  int64 HoistInstructionSize() { return instructions_to_hoist_.size(); }
+
+  // Get boundaries of this branch.
+  const std::vector<ConditionalBoundary>& boundaries() const {
+    return boundaries_;
+  }
+
+  // Get instructions to hoist in this branch.
+  const std::vector<HloInstruction*>& instructions_to_hoist() const {
+    return instructions_to_hoist_;
+  }
+
+  // Get hoist instruction set in this branch.
+  const std::unordered_set<HloInstruction*>& instructions_to_hoist_set() const {
+    return instructions_to_hoist_set_;
+  }
 
  private:
   // worklist is the deque that contains instructions to be visited.
-  std::deque<Boundary> worklist_;
+  std::deque<HloInstruction*> worklist_;
+
+  // instructions that has been visited.
+  std::unordered_set<HloInstruction*> visited_;
+
+  // parameter instructions of the branch.
+  std::unordered_set<HloInstruction*> parameter_instructions_;
+
+  // Boundaries contains the set of instructions that its operand is within
+  // conditional but it can be hoist out of conditional.
+  std::vector<ConditionalBoundary> boundaries_;
+
+  // Instructions to hoist.
+  std::unordered_set<HloInstruction*> instructions_to_hoist_set_;
+
+  // Instructions to hoist, the order within this vector is BFS and
+  // an instruction's order will always be after its users.
+  std::vector<HloInstruction*> instructions_to_hoist_;
 };
 
-// Returns estimation of potential reuses carried by a given instruction.
-// Use different integers to classify different levels of reuses
-// This is used as a placeholder only, assuming all instructions can be
-// fused to enable data reuses
-int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
-  VLOG(1) << "ConditionalCodeMotion: Add reuses carried by instr: "
-          << op->ToString() << "=>" << user->ToString() << "\n";
-  switch (user->opcode()) {
-    case HloOpcode::kGetTupleElement:
-      return 0;
-    default:
-      break;
+// Returns true if `instruction` is worth hoisting out.
+bool WorthHoisting(HloInstruction* instruction) {
+  for (const auto* operand : instruction->operands()) {
+    // Only move out instructions that won't share the same operand
+    // to avoid copy of the operand.
+    if (operand->user_count() > 1) {
+      return false;
+    }
   }
-  switch (op->opcode()) {
-    // These instructions are lightweight and easy to fuse.
-    case HloOpcode::kConstant:
-      return 0;
-    default:
-      // Assume fusion will not happen anyway if user count > 1)
-      if (op->user_count() > 1) {
-        return 0;
+  switch (instruction->opcode()) {
+    case HloOpcode::kConvert:
+      // If Convert is after AllReduce, it is worth moving out AllReduce out
+      // of conditional for AR/CRS combine. If Convert is after other ops such
+      // as Dot or Convolutional, it is better to keep convert within
+      // conditional so that convert can be fused with Dot or Convolutional.
+      //
+      // TODO(b/154283721): figure out the scenario when convert can be fused
+      // with AllReduce out of conditional.
+ if (instruction->operand(0)->opcode() == HloOpcode::kAllReduce) { + return true; } - return 10; + return false; + case HloOpcode::kAllReduce: + case HloOpcode::kAdd: + case HloOpcode::kConstant: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: + case HloOpcode::kDivide: + case HloOpcode::kTuple: + case HloOpcode::kSqrt: + case HloOpcode::kGetTupleElement: + return true; + default: + return false; } } @@ -122,7 +220,7 @@ bool InstructionWithinBranchIdentical( return *a == *b; }; - if (instructions.empty()) { + if (instructions[0] == nullptr) { return false; } @@ -150,27 +248,109 @@ bool InstructionWithinBranchIdentical( }); } -// Copy identical instructions within conditional outside of conditional. -Status CopyOutOfConditional( - Boundary& boundary, HloComputation* conditional_parent, - absl::flat_hash_map& - hoisted_instructions) { - // Insert GetTupleElement before the instructions whose operands might still - // be within the conditional. - HloInstruction* op = boundary.Operands()[0]; - absl::InlinedVector new_operands; - for (int i = 0; i < op->operands().size(); ++i) { - auto op_i = op->operands()[i]; - VLOG(2) << "Looking for operand:" << op_i->ToString() << "\n"; - CHECK(ContainsKey(hoisted_instructions, op_i)); - new_operands.push_back(FindOrDie(hoisted_instructions, op_i)); +// Returns if all the visitors/branches has next instruction to visit. +bool HasNextInstruction(const std::vector& visitors) { + bool has_next = true; + for (const auto& visitor : visitors) { + has_next &= visitor.HasNextInstruction(); } - HloInstruction* new_instruction = conditional_parent->AddInstruction( - op->CloneWithNewOperands(op->shape(), new_operands)); - // Maps the instruction outside of conditional to the instruction - // inside of the conditional. - for (HloInstruction* op : boundary.Operands()) { - hoisted_instructions[op] = new_instruction; + return has_next; +} + +// Create tuple element as the new root of the branch. The tuple will contain +// the operands that can't move out of conditional but its user will be moved +// out of conditional. +HloInstruction* CreateNewRoot( + const std::vector& boundaries, + const std::unordered_set& instructions_to_hoist_set, + HloComputation* computation) { + std::vector elements; + elements.reserve(boundaries.size()); + for (auto boundary : boundaries) { + if (ContainsKey(instructions_to_hoist_set, boundary.user)) { + elements.push_back(boundary.operand); + } + } + return computation->AddInstruction(HloInstruction::CreateTuple(elements)); +} + +// Copy identical instructions within conditional outside of conditional. +void CopyIdenticalInstructionsOutOfConditional( + const std::vector& instructions_to_hoist, + HloComputation* conditional_parent, + absl::flat_hash_map* + hoisted_instructions) { + int64 instructions_size = instructions_to_hoist.size(); + // Visit the operands before its users and copy it, so that the copied + // user will point to the correct operand. + for (int64 i = instructions_size - 1; i >= 0; i--) { + HloInstruction* old_instruction = instructions_to_hoist[i]; + auto get_new_operand = [&](HloInstruction* old_operand) { + // If the operand can't be found in `instructions_to_hoist`, this + // operand will be in the `boundaries`, GetTupleElement instructions + // will be added later to replace this operand. 
+      if (!ContainsKey(*hoisted_instructions, old_operand)) {
+        return old_operand;
+      }
+      return FindOrDie(*hoisted_instructions, old_operand);
+    };
+
+    absl::InlinedVector<HloInstruction*, 4> new_operands;
+    absl::c_transform(old_instruction->operands(),
+                      std::back_inserter(new_operands), get_new_operand);
+
+    HloInstruction* new_instruction = conditional_parent->AddInstruction(
+        old_instruction->CloneWithNewOperands(old_instruction->shape(),
+                                              new_operands));
+    // Maps the instruction outside of conditional to the instruction
+    // inside of the conditional.
+    InsertOrDie(hoisted_instructions, old_instruction, new_instruction);
+  }
+}
+
+// If there are instructions to hoist, the root of the conditional must be
+// moved out. Change the users of the conditional to the hoisted instruction
+// of the new root.
+Status ChangeConditionalUsers(
+    HloInstruction* conditional, HloInstruction* old_root,
+    const absl::flat_hash_map<HloInstruction*, HloInstruction*>&
+        hoisted_instructions) {
+  HloInstruction* new_root = FindOrDie(hoisted_instructions, old_root);
+  TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWith(new_root));
+  return Status::OK();
+}
+
+// Insert GetTupleElement before the instructions whose operands might still
+// be within the conditional.
+Status CreateGetTupleElementAfterConditional(
+    const std::vector<ConditionalBoundary>& boundaries,
+    const std::unordered_set<HloInstruction*>& instructions_to_hoist_set,
+    const absl::flat_hash_map<HloInstruction*, HloInstruction*>&
+        hoisted_instructions,
+    HloInstruction* conditional, HloComputation* computation) {
+  int boundary_instruction_size = boundaries.size();
+
+  // Inserts GetTupleElement before the boundary instructions.
+  for (int i = 0; i < boundary_instruction_size; i++) {
+    HloInstruction* gte =
+        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+            boundaries[i].operand->shape(), conditional, i));
+
+    HloInstruction* new_instruction =
+        FindOrDie(hoisted_instructions, boundaries[i].user);
+    TF_RETURN_IF_ERROR(
+        new_instruction->ReplaceOperandWith(boundaries[i].operand_index, gte));
+  }
+  return Status::OK();
+}
+
+// Remove instructions to be hoisted out of the branch computation.
+Status RemoveInstructionFromComputation(
+    const std::vector<HloInstruction*>& instructions_to_hoist,
+    HloComputation* branch) {
+  // Will visit the instructions after its users.
+  for (auto* instruction : instructions_to_hoist) {
+    TF_RETURN_IF_ERROR(branch->RemoveInstruction(instruction));
   }
   return Status::OK();
 }
@@ -394,359 +574,128 @@ StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
 // are the shape of the operands are identical and their properties are
 // identical. Will start from the root instruction of each branch and get
 // the identical ops to hoist.
-StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
-    HloInstruction* conditional, std::vector<Boundary>& to_move_out,
-    std::vector<Boundary>& new_boundaries) {
-  if (to_move_out.empty()) {
+StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
+                                      bool is_layout_sensitive) {
+  VLOG(1) << " visiting conditional:" << conditional->ToString();
+  int branch_count = conditional->branch_count();
+  if (branch_count <= 0) {
     return false;
   }
-  VLOG(1) << "number of boundaries to move out:" << to_move_out.size() << "\n";
-  HloComputation* conditional_parent = conditional->parent();
-  // save the old users before add new conditional user instructions
-  std::vector<HloInstruction*> old_conditional_users = conditional->users();
-  absl::flat_hash_map<HloInstruction*, HloInstruction*> hoisted_instructions;
-  // Maps instructions in the conditional body to instructions hoisted outside
-  // the conditional that compute the same value.
- VLOG(2) << "before opt:" - << conditional_parent->ToString(HloPrintOptions::Fingerprint()) - << "\n"; - int64 op_index = 0; - for (Boundary b : new_boundaries) { - HloInstruction* op = b.Operands()[0]; - CHECK(op != nullptr); - VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n"; - HloInstruction* gtr = conditional_parent->AddInstruction( - HloInstruction::CreateGetTupleElement(op->shape(), conditional, - op_index++)); - hoisted_instructions[op] = gtr; + + std::vector visitors; + visitors.reserve(branch_count); + // Visit instructions from the root instruction to the operands using BFS. + for (int i = 0; i < branch_count; i++) { + visitors.emplace_back(BranchVisitor(conditional->branch_computation(i))); } - // Copy boundary instructions out of the conditional. - // Visit the operands before its users and copy it, so that the copied - // user will point to the correct operand. - for (int64 i = to_move_out.size() - 1; i >= 0; i--) { - TF_RETURN_IF_ERROR(CopyOutOfConditional(to_move_out[i], conditional_parent, - hoisted_instructions)); + + // The instructions to be visited within each branch. + std::vector front_instructions(branch_count); + + while (HasNextInstruction(visitors)) { + for (int i = 0; i < branch_count; i++) { + front_instructions[i] = visitors[i].GetNextInstruction(); + } + // If two instructions has the same shape, opcode and its operands has the + // same shape, then this instruction can be moved out of conditional. + if (WorthHoisting(front_instructions[0]) && + InstructionWithinBranchIdentical(front_instructions, + is_layout_sensitive)) { + for (int i = 0; i < branch_count; i++) { + visitors[i].AddInstructionOperands(front_instructions[i]); + visitors[i].AddInstructionToHoist(front_instructions[i]); + } + } else { + for (int i = 0; i < branch_count; i++) { + // If the ops are not identical, these ops and its users will + // be in the boundaries` of the conditional. These ops will be stayed + // within the conditional, but one its only user will be moved out + // of conditional. + visitors[i].AddInstructionToBoundary(front_instructions[i]); + } + } } - VLOG(2) << "Done copy branch instructions out\n" - << conditional_parent->ToString(HloPrintOptions::Fingerprint()) - << "\n"; - // Change original users of the conditional to use the correct operands. + + if (visitors[0].HoistInstructionSize() < 1) { + return false; + } + HloInstruction* old_root = conditional->branch_computation(0)->root_instruction(); - for (auto user_instr : old_conditional_users) { - CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement); - auto tuple_opd = down_cast(user_instr); - int64 index = tuple_opd->tuple_index(); - HloInstruction* old_opd = old_root->operands()[index]; - HloInstruction* new_opd = hoisted_instructions[old_opd]; - CHECK(old_opd != nullptr); - CHECK(new_opd != nullptr); - TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd)); - TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr)); - } + HloComputation* conditional_parent = conditional->parent(); + // Maps instructions in the conditional body to instructions hoisted outside + // the conditional that compute the same value. + absl::flat_hash_map hoisted_instructions; + // Copy identical instructions out of the conditional. + CopyIdenticalInstructionsOutOfConditional(visitors[0].instructions_to_hoist(), + conditional_parent, + &hoisted_instructions); + // If there are instructions to hoist, the root of the conditional must be + // moved out. 
Change the users of the conditional to the hoisted instruction + // of the new root. + TF_RETURN_IF_ERROR( + ChangeConditionalUsers(conditional, old_root, hoisted_instructions)); + // Create tuple element within each branch and set it as root. - int64 branch_count = conditional->branch_count(); for (int i = 0; i < branch_count; i++) { - auto computation = conditional->branch_computation(i); - std::vector elements; - for (auto b1 : new_boundaries) { - HloInstruction* op = b1.Operands()[i]; - VLOG(1) << "branch count=" << i << "\n"; - CHECK(op != nullptr); - VLOG(1) << "Adding to root " << i << " with " << op->ToString() << "\n"; - elements.push_back(op); - } - HloInstruction* tuple = - computation->AddInstruction(HloInstruction::CreateTuple(elements)); - computation->set_root_instruction(tuple, true); - VLOG(2) << "computation is :" << computation->ToString() << "\n"; - // Remove hoisted instructions from the branches. - for (auto b2 : to_move_out) { - VLOG(2) << "Removing boundary:" << b2.ToString() << "\n"; - TF_RETURN_IF_ERROR(computation->RemoveInstruction(b2.Operands()[i])); - } + HloInstruction* tuple = CreateNewRoot( + visitors[i].boundaries(), visitors[i].instructions_to_hoist_set(), + conditional->branch_computation(i)); + conditional->branch_computation(i)->set_root_instruction(tuple, true); + } + // Changes conditional instruction shape to the shape of the new root. + *conditional->mutable_shape() = + conditional->branch_computation(0)->root_instruction()->shape(); + + // Insert GetTupleElement before the instructions whose operands might still + // be within the conditional. + TF_RETURN_IF_ERROR(CreateGetTupleElementAfterConditional( + visitors[0].boundaries(), visitors[0].instructions_to_hoist_set(), + hoisted_instructions, conditional, conditional_parent)); + + // Remove hoist instructions from the branches. + for (int i = 0; i < branch_count; i++) { + TF_RETURN_IF_ERROR( + RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(), + conditional->branch_computation(i))); } - // Change conditional instruction shape to the shape of the new root. - HloInstruction* new_root = - conditional->branch_computation(0)->root_instruction(); - *conditional->mutable_shape() = new_root->shape(); - // - VLOG(2) << "done moving instructions out of branches\n" - << conditional_parent->ToString(HloPrintOptions::Fingerprint()) - << "\n"; return true; } -// Group single chains of operands or uses of boundaries into new boundaries -class GroupConnectedBoundaries { - private: - std::unordered_set visited_; - std::vector connected_boundaries_, new_boundaries_; - HloInstruction* conditional_; - bool is_layout_sensitive_; - - public: - explicit GroupConnectedBoundaries(HloInstruction* conditional, - bool is_layout_sensitive) - : conditional_(conditional), is_layout_sensitive_(is_layout_sensitive) {} - // Returns true if `instruction` is worth hoisting out. - bool WorthHoisting(HloInstruction* instruction) { - switch (instruction->opcode()) { - case HloOpcode::kConvert: - // If Convert is after AllReduce, it is worth moving out AllReduce out - // of conditional for AR/CRS combine. If Convert is after other ops such - // as Dot or Convolutional, it is better to keep convert within - // conditional so that convert can be fused with Dot or Convolutional. - // - // TODO(b/154283721): figure out the scenario when convert can be fused - // with AllReduce out of conditional. 
- switch (instruction->operand(0)->opcode()) { - case HloOpcode::kAllReduce: - case HloOpcode::kReshape: - return true; - default: - VLOG(1) << "Instruction is convert and its operand is not know to " - "be worth hoisting\n"; - return false; - } - case HloOpcode::kAllReduce: - case HloOpcode::kAdd: - case HloOpcode::kConstant: - case HloOpcode::kSubtract: - case HloOpcode::kMultiply: - case HloOpcode::kDivide: - case HloOpcode::kTuple: - case HloOpcode::kSqrt: - case HloOpcode::kReshape: - case HloOpcode::kGetTupleElement: - return true; - default: - VLOG(1) << "Instruction is not known to be worth hoisting\n"; - return false; - } - } - // Calculates the degree of reuses carried by a pair of conditional - // boundaries, if b1 is inside a conditional and b2 is outside. - int64 ReusesBeforeBoundary(HloInstruction* user) { - int64 reuses = 0; - for (auto op : user->operands()) { - // Only consider single-user cases as reuseable. - if (ContainsKey(visited_, op) && op->user_count() == 1) { - reuses += ReusesCarriedBy(op, user); - } - } - VLOG(1) << "cost to be paied after moving out" << user->ToString() << ":" - << reuses << "\n"; - return reuses; - } - - int64 ReusesAfterBoundary(HloInstruction* user) { - CHECK(user != nullptr); - auto all_users = user->users(); - // For now, assume that if an instruction has multiple-consumers, it will - // not be reused (the reuse currently requires duplication in fusion and so - // is expensive). - if (all_users.size() > 1) { - return 0; - } - if (!all_users.empty()) { - auto op = all_users[0]; - int64 reuses = 0; - // Only count reuses that run through the conditional root. - if (op == conditional_->branch_computation(0)->root_instruction()) { - int64 index = op->operand_index(user); - for (auto op2 : conditional_->users()) { - CHECK(op2->opcode() == HloOpcode::kGetTupleElement); - auto tuple_opd = down_cast(op2); - if (index == tuple_opd->tuple_index()) { - all_users = op2->users(); - if (!all_users.empty()) { - reuses += ReusesCarriedBy(user, all_users[0]); - break; - } - } - } - } - VLOG(1) << "reuses to be gained after moving " << user->ToString() << ":" - << reuses << "\n"; - return reuses; - } - return 0; - } - - int64 BenefitForMovingBoundaries(const std::vector& boundaries) { - int64 reuses_before = 0, reuses_after = 0; - for (Boundary b : boundaries) { - auto op = b.Operands()[0]; - if (op == conditional_->branch_computation(0)->root_instruction()) { - continue; - } - reuses_before += ReusesBeforeBoundary(op); - VLOG(1) << "Cost of moving so far: " << reuses_before << "\n"; - reuses_after += ReusesAfterBoundary(op); - VLOG(1) << "Benefit from moving so far : " << reuses_after << "\n"; - } - if (reuses_after == 0 && reuses_before == 0) { - return -1; - } else if (boundaries[0].IsInsideBranch()) { - return reuses_after - reuses_before; - } else { - return reuses_before - reuses_after; - } - } - - Boundary GetNextBoundary(const Boundary& b, int64 op_index) { - Boundary b2(b.GetPosition()); - CHECK(b.Operands().size() == conditional_->branch_count()); - for (int j = 0; j < b.Operands().size(); ++j) { - HloInstruction* inst = b.Operands()[j]; - CHECK(inst != nullptr); - HloInstruction* op = (b.IsInsideBranch()) ? 
inst->operands()[op_index] - : inst->users()[op_index]; - CHECK(op != nullptr); - b2.Operands().push_back(op); - } - return b2; - } - void AddBoundaries(const Boundary& boundary) { - BoundaryVisitor visitor; - visitor.AddToWorkList(boundary); - while (visitor.HasNextBoundary()) { - Boundary b = visitor.PopNextBoundary(); - // if b is already visited, it must have multiple users and is already in - // new boundaries. Skip it. - if (ContainsKey(visited_, b.Operands()[0])) { - continue; - } - VLOG(1) << "visiting boundary " << b.ToString() << "\n"; - if ((b.Operands().size() == 1 || - InstructionWithinBranchIdentical(b.Operands(), - is_layout_sensitive_)) && - WorthHoisting(b.Operands()[0])) { - connected_boundaries_.push_back(b); - VLOG(1) << "boundary can be moved\n"; - int64 operand_count = (b.IsInsideBranch()) - ? b.Operands()[0]->operand_count() - : b.Operands()[0]->users().size(); - for (int i = 0; i < operand_count; i++) { - Boundary b2 = GetNextBoundary(b, i); - int64 b2_count = (b2.IsInsideBranch()) - ? b2.Operands()[0]->user_count() - : b2.Operands()[0]->operand_count(); - // only consider adding an exclusive producor into the same group. - if (b2_count == 1) { - VLOG(2) << "Add operand " << i << " to visit later\n"; - visitor.AddToWorkList(b2); - } else { - VLOG(2) << "Operand " << i << " has multiple uses\n"; - if (!ContainsKey(visited_, b2.Operands()[0])) { - visited_.insert(b2.Operands()[0]); - new_boundaries_.push_back(b2); - } - } - } - } else { - VLOG(1) << "boundary cannot be moved\n"; - visited_.insert(b.Operands()[0]); - new_boundaries_.push_back(b); - } - } - } - std::vector BoundariesToMoveOut(const Boundary& b) { - HloInstruction* inst = b.Operands()[0]; - if (inst->opcode() == HloOpcode::kConditional) { - int branch_count = inst->branch_count(); - // Visit instructions from the root instruction to the operands using BFS. - Boundary boundary_in(Boundary::Position::kInsideBranch); - for (int i = 0; i < branch_count; i++) { - HloComputation* branch_computation = inst->branch_computation(i); - HloInstruction* root_inst = branch_computation->root_instruction(); - CHECK(root_inst != nullptr); - boundary_in.Operands().push_back(root_inst); - } - AddBoundaries(boundary_in); - } - return connected_boundaries_; - } - std::vector BoundariesToMoveIn(const Boundary& b) { - if (b.IsInsideBranch()) { - return std::vector(); - } - AddBoundaries(b); - return connected_boundaries_; - } - std::vector GetNewBoundaries() { return new_boundaries_; } -}; - -ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( - HloInstruction* conditional, const Boundary& cur_boundary, - std::vector& to_move, std::vector& new_boundaries) { - GroupConnectedBoundaries connect(conditional, is_layout_sensitive_); - auto move_out = connect.BoundariesToMoveOut(cur_boundary); - if (!move_out.empty()) { - std::vector next_boundaries = connect.GetNewBoundaries(); - auto benefit = connect.BenefitForMovingBoundaries(move_out); - VLOG(1) << "benefit of moving " << cur_boundary.Operands()[0]->ToString() - << ":" << benefit << "\n"; - if (benefit >= 0) { - new_boundaries = next_boundaries; - to_move = move_out; - return Decision::kMoveOutOfBranch; - } - } - return ConditionalCodeMotion::Decision::kNoChange; -} +} // namespace StatusOr ConditionalCodeMotion::Run(HloModule* module) { - // Gather all the conditional ops in the module ahead of time, to avoid - // potential complications of modifying the code that affecting traversal. 
- std::vector conditional_ops; - for (auto* comp : module->MakeComputationPostOrder()) { - for (auto* instr : comp->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kConditional) { - conditional_ops.push_back(instr); + bool changed = false; + + if (pursue_full_conditional_code_motion_) { + std::vector conditional_ops; + for (auto* comp : module->MakeComputationPostOrder()) { + for (auto* instr : comp->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kConditional) { + conditional_ops.push_back(instr); + } } } + + for (HloInstruction* conditional_op : conditional_ops) { + TF_ASSIGN_OR_RETURN( + bool result, + MergeIdenticalElements(conditional_op, is_layout_sensitive_)); + changed |= result; + } + + if (changed) { + HloPassPipeline subpipeline("after_conditional_code_motion"); + subpipeline.AddPass(); + subpipeline.AddPass(); + subpipeline.AddPass(); + TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); + changed |= cleanup_changed; + } } - bool changed = false; - std::vector to_move_out, to_move_in, new_boundaries; - for (HloInstruction* conditional : conditional_ops) { - BoundaryVisitor visitor(conditional); - VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n"; - // Boundariess to move out of and to move into the branches. - while (visitor.HasNextBoundary()) { - std::vector to_move, next_boundary; - Boundary boundary = visitor.PopNextBoundary(); - VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n"; - ConditionalCodeMotion::Decision d = - ConsiderCodeMotion(conditional, boundary, to_move, next_boundary); - switch (d) { - case Decision::kMoveOutOfBranch: - VLOG(2) << "Decision is move out of branch\n"; - to_move_out.insert(to_move_out.end(), to_move.begin(), to_move.end()); - break; - case Decision::kMoveIntoBranch: - VLOG(2) << "Decision is move into branch\n"; - to_move_in.insert(to_move_in.end(), to_move.begin(), to_move.end()); - break; - case Decision::kNoChange: - VLOG(2) << "Decision is no change\n"; - new_boundaries.push_back(boundary); - break; - } - for (const Boundary& b : next_boundary) { - visitor.AddToWorkList(b); - } - } - TF_ASSIGN_OR_RETURN( - bool result, - MoveInstructionOut(conditional, to_move_out, new_boundaries)); - VLOG(2) << "moving out result:" << result << "\n"; - changed |= result; - } // handling convert rematerialization/hoisting - if (!changed && pursue_full_conditional_code_motion_) { + { std::vector conditional_ops; for (auto* comp : module->MakeComputationPostOrder()) { for (auto* instr : comp->MakeInstructionPostOrder()) { @@ -762,6 +711,7 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { changed |= convert_result; } } + if (changed) { HloPassPipeline subpipeline( "after_conditional_code_motion_after_convert_hoisting"); @@ -771,8 +721,8 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); changed |= cleanup_changed; } + return changed; } -} // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.h b/tensorflow/compiler/xla/service/conditional_code_motion.h index d7295058467..95f02833e15 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.h +++ b/tensorflow/compiler/xla/service/conditional_code_motion.h @@ -23,80 +23,35 @@ limitations under the License. 
 namespace xla {
-namespace conditional_opt {
 
-// At the conceptural level, a boundary can be thought of as representing a
-// single virtual operation, except this virtual operation is conditionally
-// instantiated into different concrete operations at each conditional branch.
-// So a boundary is mapped to a single concrete operation if it is outside of
-// conditional branches, and is mapped to a list of instructions if inside the
-// branches. This data structure therefore allows a common data structure
-// representation of the instructions to be moved, whether they are inside or
-// outside of the branches. Subsequently, it allows a common implementation
-// basis to be used for both moving instructions out of and for moving them
-// inside branches.
-class Boundary {
- public:
-  enum class Position { kInsideBranch, kOutsideBranch };
-  explicit Boundary(Position p) : position_(p) {}
-  std::vector<HloInstruction*>& Operands() { return operands_; }
-  const std::vector<HloInstruction*>& Operands() const { return operands_; }
-  bool IsInsideBranch() const { return position_ == Position::kInsideBranch; }
-  bool IsOutsideBranch() const { return position_ == Position::kOutsideBranch; }
-  Position GetPosition() const { return position_; }
-  bool IsEmpty() const { return operands_.empty(); }
-  std::string ToString() const {
-    std::string res;
-    for (HloInstruction* op : operands_) {
-      res += op->ToString() + ";";
-    }
-    return res;
-  }
-
- private:
-  // Boundary instructions in the conditional branches, one from each branch
-  // of the conditional.
-  std::vector<HloInstruction*> operands_;
-  Position position_;
-};
-
-// HLO pass that moves identical ops in/out of conditional.
+// ConditionalCodeMotion specializes in hoisting/rematerializing
+// unconditional converts in the default mode.
+// When pursue_full_conditional_code_motion_ is set to true, the
+// full HLO pass moves identical ops out of a conditional in addition to moving
+// converts.
 // - The definition of identical are the shape of the operands are identical
 // and their properties are identical.
+// - Currently, only some types of instructions is supported.
+// TODO(b/154283721): relax non-sharable operand constraint and avoid copies in
+// the new root.
 // - Only the identical ops that won't share operands with other ops will
 // be moved out of conditional.
 class ConditionalCodeMotion : public HloModulePass {
  public:
   // If is_layout_sensitive is true, then the hoist process preserves layout
   // during identical comparison. Otherwise, layout is ignored.
-  explicit ConditionalCodeMotion(bool is_layout_sensitive,
-                                 bool pursue_full_conditional_code_motion)
+  explicit ConditionalCodeMotion(
+      bool is_layout_sensitive = true,
+      bool pursue_full_conditional_code_motion = false)
       : is_layout_sensitive_(is_layout_sensitive),
         pursue_full_conditional_code_motion_(
            pursue_full_conditional_code_motion) {}
   absl::string_view name() const override { return "conditional-code-motion"; }
   StatusOr<bool> Run(HloModule* module) override;
 
-  // Optimization decision for each boundary of the conditional instruction.
-  enum class Decision { kMoveOutOfBranch, kMoveIntoBranch, kNoChange };
-  // If the optimization decision is NO_CHANGE, new_boundary is set to nullptr;
-  // otherwise, it is set to the new boundary after proposed optimization.
- virtual Decision ConsiderCodeMotion(HloInstruction* conditional, - const Boundary& cur_boundary, - std::vector& to_move, - std::vector& new_boundaries); - private: const bool is_layout_sensitive_; const bool pursue_full_conditional_code_motion_; - - StatusOr MoveInstructionOut(HloInstruction* conditional, - std::vector& to_move_out, - std::vector& new_boundaries); - StatusOr MoveInstructionIn(HloInstruction* conditional, - std::vector& to_move_in, - std::vector& new_boundaries); }; -} // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index b3c5e17094a..38b2b515fa0 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace conditional_opt { +namespace { using ConditionalCodeMotionTest = HloTestBase; namespace op = xla::testing::opcode_matchers; @@ -117,47 +117,6 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::Tuple(op::Convert()))); } -TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditional) { - absl::string_view hlo_string = - R"( -HloModule RemoveDotOpOut - -on_true { - %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0) - %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 - %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1) - %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493) - %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493) - ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894) -} - -on_false { - %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0) - %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 - %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3) - %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717) - %sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717) - %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"} - ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604) -} - -ENTRY main { - pred.1 = pred[] parameter(0) - arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1) - arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) - conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false - get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0 - ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); - ConditionalCodeMotion pass(true, true); - ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); - - HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, AllOf(op::Tuple(op::Convert()))); -} - TEST_F(ConditionalCodeMotionTest, MoveConvertOut) { absl::string_view hlo_string = R"( @@ -193,20 +152,8 @@ ENTRY main { ConditionalCodeMotion pass(true, true); ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); - const HloInstruction* conditional = - FindInstruction(module.get(), "conditional"); - const 
HloComputation* on_true = conditional->branch_computation(0); - ASSERT_EQ(on_true->instruction_count(), 2); - const HloComputation* on_false = conditional->branch_computation(1); - ASSERT_EQ(on_false->instruction_count(), 2); - HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT( - root, - AllOf(op::Tuple(op::Add(op::Convert(op::Reshape(op::GetTupleElement( - op::GetTupleElement(op::Conditional())))), - op::Convert(op::Reshape(op::GetTupleElement( - op::GetTupleElement(op::Conditional())))))))); + EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert())))); } TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) { @@ -226,7 +173,7 @@ on_true { add.2 = f32[] add(add.1, constant.2) add.3 = f32[] add(add.1, constant.3) add.4 = f32[] add(add.3, constant.5) - multiply.1 = f32[] multiply(add.4, constant.4) + multiply.1 = f32[] multiply(add.2, constant.4) ROOT tuple.6 = (f32[], f32[]) tuple(multiply.1, add.4) } @@ -269,11 +216,13 @@ ENTRY main { const HloComputation* on_false = conditional->branch_computation(1); ASSERT_EQ(on_false->instruction_count(), 9); + // Check only one add and multiply is moved out. HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT( - root, AllOf(op::Tuple(op::Multiply(op::GetTupleElement(op::Conditional()), - op::Constant()), - op::GetTupleElement(op::Conditional())))); + root, + AllOf(op::Tuple( + op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()), + op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))); } TEST_F(ConditionalCodeMotionTest, ConditionalRootElementChanged) { @@ -320,16 +269,16 @@ ENTRY main { const HloInstruction* conditional = FindInstruction(module.get(), "conditional"); const HloComputation* on_true = conditional->branch_computation(0); - ASSERT_EQ(on_true->instruction_count(), 1); + ASSERT_EQ(on_true->instruction_count(), 7); const HloComputation* on_false = conditional->branch_computation(1); - ASSERT_EQ(on_false->instruction_count(), 1); + ASSERT_EQ(on_false->instruction_count(), 7); - HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT( - root, - AllOf(op::Tuple(op::Add( - op::Add(op::GetTupleElement(op::Conditional()), op::Constant()), - op::Add(op::GetTupleElement(op::Conditional()), op::Constant()))))); + // add.3 in on_true will be moved out, add.1 and add.2 will be in condtional + // root. + ASSERT_TRUE(ShapeUtil::Compatible( + conditional->shape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}))); } TEST_F(ConditionalCodeMotionTest, ConditionalIsRootInstruction) { @@ -380,9 +329,24 @@ ENTRY main { )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); ConditionalCodeMotion pass(true, true); - // If there is no instruction after the conditional, there is no benefit to - // move - ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 9); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 9); + + // Check only one add and multiply is moved out. + // add.3 and add.5 can't be moved out because they share operands with + // other instructions. 
+ HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Tuple( + op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()), + op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))); } TEST_F(ConditionalCodeMotionTest, LayoutMisMatchCannotMovedOut) { @@ -505,8 +469,7 @@ ENTRY main { false_computation=on_false get-first-index = f32[3,3,128,128] get-tuple-element(conditional), index=0 - add.1 = f32[3,3,128,128] add(f32[3,3,128,128] get-first-index, f32[3,3,128,128] get-first-index) - ROOT result = (f32[3,3,128,128]) tuple(add.1) + ROOT result = (f32[3,3,128,128]) tuple(get-first-index) } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); @@ -524,14 +487,10 @@ ENTRY main { conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape( BF16, {3, 3, 128, 128})}))); HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT( - root, - AllOf(op::Tuple(op::Add( - op::Convert(op::AllReduce(op::GetTupleElement(op::Conditional()))), - op::Convert( - op::AllReduce(op::GetTupleElement(op::Conditional()))))))); + EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce( + op::GetTupleElement(op::Conditional())))))); } -} // namespace conditional_opt +} // namespace } // namespace xla
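For reference, a minimal sketch of how the pass restored by this rollback is driven, mirroring the constructor and Run() signature in conditional_code_motion.h and the flags the tests above pass in; the wrapper function name below is hypothetical and not part of the patch.

#include "tensorflow/compiler/xla/service/conditional_code_motion.h"

namespace xla {

// Hypothetical helper: runs the pass with the same configuration as the tests.
// is_layout_sensitive=true makes the identical-instruction comparison respect
// layouts; pursue_full_conditional_code_motion=true hoists all supported
// identical ops instead of only the convert rematerialization/hoisting path.
StatusOr<bool> HoistIdenticalConditionalCode(HloModule* module) {
  ConditionalCodeMotion pass(/*is_layout_sensitive=*/true,
                             /*pursue_full_conditional_code_motion=*/true);
  return pass.Run(module);  // true iff the module was changed.
}

}  // namespace xla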