diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 89cb7a41442..c4549231fb4 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -485,12 +485,13 @@ class CopyRemover {
   };
 
   CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
-              const HloOrdering& ordering)
+              const HloOrdering& ordering, bool check_live_range_ordering)
       : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
     // Construct a list for each HLO buffer in the alias analysis. Maintain a
     // map from HloValue to the respective list element representing that
     // value. The map is used to construct the copy info map below.
     absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
+    // Perform check only if the default dependence-based ordering is used.
     for (const HloBuffer& buffer : alias_analysis.buffers()) {
       // No copies should have been inserted within fused computations, so no
       // need to remove them. HloOrdering isn't compatible with HloValues inside
@@ -498,24 +499,26 @@ class CopyRemover {
       if (buffer.values().at(0)->defining_instruction()->IsFused()) {
         continue;
       }
-      // Verify values contained in the buffer are strictly ordered. This
-      // should always be the case after adding copies to eliminate
-      // interference. Specifically, the addition of the control flow edges
-      // between copies added around aliased operations (kWhile) guarantees
-      // this strict order.
-      for (const HloValue* value_a : buffer.values()) {
-        if (value_a->shape().IsToken()) {
-          // Token values have no representation and cannot interfere.
-          continue;
-        }
-        for (const HloValue* value_b : buffer.values()) {
-          if (value_a != value_b) {
-            DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
-                                                     dataflow_) ||
-                   ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
-                                                     dataflow_))
-                << value_a->ToShortString() << " and "
-                << value_b->ToShortString() << " are not ordered";
+      if (check_live_range_ordering) {
+        // Verify values contained in the buffer are strictly ordered. This
+        // should always be the case after adding copies to eliminate
+        // interference. Specifically, the addition of the control flow edges
+        // between copies added around aliased operations (kWhile) guarantees
+        // this strict order.
+        for (const HloValue* value_a : buffer.values()) {
+          if (value_a->shape().IsToken()) {
+            // Token values have no representation and cannot interfere.
+            continue;
+          }
+          for (const HloValue* value_b : buffer.values()) {
+            if (value_a != value_b) {
+              DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
+                                                       dataflow_) ||
+                     ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
+                                                       dataflow_))
+                  << value_a->ToString() << " and " << value_b->ToString()
+                  << " are not ordered";
+            }
           }
         }
       }
@@ -729,27 +732,31 @@ class CopyRemover {
       VLOG(2) << copy->name() << " defines the first value in its buffer";
       for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
            next_dest = Next(*next_dest)) {
-        // Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
-        if (!LiveRangeBefore(*src, *next_dest)) {
-          VLOG(2) << "Not removing the copy: live range of "
-                  << src->value->ToShortString() << " is not before "
-                  << next_dest->value->ToShortString();
-          return false;
+        // Live range of (s_x, s_{x-1},...) must be before 'next_dest' (d_1);
+        for (ValueNode* prev_src = src; prev_src != nullptr;
+             prev_src = Prev(*prev_src)) {
+          if (!LiveRangeBefore(*prev_src, *next_dest)) {
+            VLOG(2) << "Not removing the copy: live range of "
+                    << prev_src->value->ToShortString() << " is not before "
+                    << next_dest->value->ToShortString();
+            return false;
+          }
         }
       }
       for (ValueNode* next_src = Next(*src); next_src != nullptr;
            next_src = Next(*next_src)) {
         // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
-        ValueNode* last_dest = dest->prev;
-        DCHECK(IsTail(*last_dest));
-        if (!LiveRangeBefore(*last_dest, *next_src)) {
-          VLOG(2) << "Not removing the copy: live range of "
-                  << last_dest->value->ToShortString() << " is not before "
-                  << next_src->value->ToShortString();
-          return false;
+        for (ValueNode* last_dest = dest->prev; last_dest != nullptr;
+             last_dest = Prev(*dest)) {
+          if (!LiveRangeBefore(*last_dest, *next_src)) {
+            VLOG(2) << "Not removing the copy: live range of "
+                    << last_dest->value->ToShortString() << " is not before "
+                    << next_src->value->ToShortString();
+            return false;
+          }
         }
       }
-
+      VLOG(2) << "Splice dest after source.";
       // Splice in destination buffer values list right after 'src'.
       SpliceAfter(dest, src);
     } else if (IsTail(*src)) {
@@ -769,32 +776,36 @@ class CopyRemover {
       VLOG(2) << copy->name() << " copies the last value ("
               << src->value->ToShortString() << ") in its buffer";
 
-      ValueNode* first_src = src->next;
-      DCHECK(IsHead(*first_src));
-      for (ValueNode* prev_dest = Prev(*dest);
-           // nullptr condition handled above in the first 'if' case.
-           prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
-        if (!LiveRangeBefore(*prev_dest, *first_src)) {
-          // Live range of value d_{y-1} is not before s_0.
-          VLOG(2) << "Not removing the copy: live range of "
-                  << prev_dest->value->ToShortString() << " is not before "
-                  << first_src->value->ToShortString();
-          return false;
+      for (ValueNode* next_src = src->next; next_src != nullptr;
+           next_src = Next(*next_src)) {
+        for (ValueNode* prev_dest = Prev(*dest);
+             // nullptr condition handled above in the first 'if' case.
+             prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
+          if (!LiveRangeBefore(*prev_dest, *next_src)) {
+            // Live range of value d_{y-1} is not before s_0.
+            VLOG(2) << "Not removing the copy: live range of "
+                    << prev_dest->value->ToShortString() << " is not before "
+                    << next_src->value->ToShortString();
+            return false;
+          }
        }
       }
       for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
            next_dest = Next(*next_dest)) {
-        if (!LiveRangeBefore(*src, *next_dest)) {
-          // Live range of value s_n is not before d_{y+1}.
-          VLOG(2) << "Not removing the copy: live range of "
-                  << src->value->ToShortString() << " is not before "
-                  << next_dest->value->ToShortString();
-          return false;
+        for (ValueNode* prev_src = src; prev_src != nullptr;
+             prev_src = Prev(*prev_src)) {
+          if (!LiveRangeBefore(*prev_src, *next_dest)) {
+            // Live range of value s_n is not before d_{y+1}.
+            VLOG(2) << "Not removing the copy: live range of "
+                    << prev_src->value->ToShortString() << " is not before "
+                    << next_dest->value->ToShortString();
+            return false;
+          }
         }
       }
-
+      VLOG(2) << "Splice src after prev of dest.";
       // Splice source buffer values list right after 'prev_dest'.
-      SpliceAfter(first_src, Prev(*dest));
+      SpliceAfter(src->next, Prev(*dest));
     } else {
       VLOG(2) << copy->name()
               << " copies value in middle of source buffer to value in middle "
@@ -1175,12 +1186,14 @@ static int64 GetNumExistingCopies(const HloModule* module) {
 }
 
 Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
-                                              HloModule* module) {
+                                              HloModule* module,
+                                              bool check_live_range_ordering) {
   XLA_VLOG_LINES(4, module->ToString());
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                       HloAliasAnalysis::Run(module, can_share_buffer_));
 
-  CopyRemover copy_remover(*module, *alias_analysis, ordering);
+  CopyRemover copy_remover(*module, *alias_analysis, ordering,
+                           check_live_range_ordering);
   if (VLOG_IS_ON(3)) {
     LOG(INFO) << "Removing unnecessary copies in " << module->name();
     LOG(INFO) << "Buffer values, in dependency order: ";
@@ -1200,7 +1213,9 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
     VLOG(2) << "Running fixpoint iteration " << num_iterations
             << " of copy elision";
     for (HloComputation* computation : module->computations()) {
+      VLOG(2) << "computation:" << computation->name() << "\n";
       for (HloInstruction* instruction : computation->instructions()) {
+        VLOG(2) << instruction->ToString() << "\n";
         if (instruction->opcode() == HloOpcode::kCopy &&
             copy_remover.TryElideCopy(instruction)) {
           changed = true;
@@ -1260,7 +1275,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
       name(), "after adding copies to resolve interference", *module);
 
   TF_RETURN_IF_ERROR(
-      RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
+      RemoveUnnecessaryCopies(DependencyHloOrdering(module), module,
+                              /*check_live_range_ordering=*/true));
   DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
                                    *module);
   TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 22dffbf741c..5f067df6b06 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -63,8 +63,10 @@ class CopyInsertion : public HloModulePass {
   // Try to remove as many copies from the module as possible without
   // introducing live range interference. Only copy instructions that are
  // eligible for copy elision are considered for removal.
-  Status RemoveUnnecessaryCopies(const HloOrdering& ordering,
-                                 HloModule* module);
+  // If check_live_range_ordering is true, check that live ranges are ordered
+  // in all the existing aliased buffers.
+  Status RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module,
+                                 bool check_live_range_ordering = false);
 
   // Add copies to address special constraints on the roots of computations not
   // related to live range interference:
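
Usage note (not part of the patch): a minimal sketch of how the new parameter might be driven by a caller. Only the RemoveUnnecessaryCopies and DependencyHloOrdering signatures and the default value of false come from the diff above; the default-constructed pass object, the HloModule* module in scope, and the Status-returning context are assumed scaffolding.

  // Hypothetical caller sketch. With the new default
  // (check_live_range_ordering = false) the strict live-range ordering DCHECK
  // is skipped; CopyInsertion::Run itself opts back in by passing
  // /*check_live_range_ordering=*/true, as shown in the diff.
  CopyInsertion copy_insertion;
  TF_RETURN_IF_ERROR(copy_insertion.RemoveUnnecessaryCopies(
      DependencyHloOrdering(module), module));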