diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 37d523b7573..ecbe8fdadf1 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -519,8 +519,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:logical_buffer", - "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index 20ee4f12e53..0283cc64341 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -21,8 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -64,13 +62,11 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) { instructions_to_outline.clear(); HloInstruction* outline_candidate = instruction; instructions_to_outline.push_back(outline_candidate); - bool all_bitcasts = outline_candidate->opcode() == HloOpcode::kBitcast; // Outline sole users with the current instruction. while (CanOutlineWithUser(outline_candidate)) { HloInstruction* prior_candidate = outline_candidate; outline_candidate = *outline_candidate->users().begin(); - all_bitcasts &= outline_candidate->opcode() == HloOpcode::kBitcast; if (std::any_of(outline_candidate->operands().begin(), outline_candidate->operands().end(), [&](const HloInstruction* operand) { @@ -86,24 +82,6 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) { } instructions_to_outline.push_back(outline_candidate); } - // If all instructions in the outline candidates are a bitcast, then create - // a copy at the head of the bitcasts and include it in the outlined - // instructions. The underlying problem is that a computation which forwards - // a parameter buffer to the output is not properly handled by the backends - // or analysis. - // - // This would be better handled by being smarter about choosing outline - // candidates in the first place. - if (all_bitcasts) { - // 'head' is the first instruction in the chain of bitcasts. - HloInstruction* head = instructions_to_outline[0]; - HloInstruction* head_operand = head->mutable_operand(0); - HloInstruction* copy = - entry_computation->AddInstruction(HloInstruction::CreateUnary( - head_operand->shape(), HloOpcode::kCopy, head_operand)); - TF_RETURN_IF_ERROR(head->ReplaceOperandWith(0, copy)); - instructions_to_outline.insert(instructions_to_outline.begin(), copy); - } outlined.insert(instructions_to_outline.begin(), instructions_to_outline.end()); @@ -122,25 +100,6 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) { changed = true; } - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - HloInstruction* root = computation->root_instruction(); - // Copy root instruction if it does not define its own top-level buffer. - // TODO(b/32885001) Remove these copies (at least for the unambiguous case). - // TODO(b/32885001) Perform shallow copy if root value is a tuple. - if (!points_to_analysis->InstructionDefinesBufferAtIndex(root, - /*index=*/{})) { - HloInstruction* copy = computation->AddInstruction( - HloInstruction::CreateUnary(root->shape(), HloOpcode::kCopy, root)); - computation->set_root_instruction(copy); - changed = true; - } - } - XLA_VLOG_LINES(2, "ParallelizationPreparation EXIT"); XLA_VLOG_LINES(2, module->ToString()); return changed;