diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index fa0db40b065..79b010e2f1b 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -193,8 +193,8 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, // Add kCopy instructions around the given kWhile instruction to eliminate any // possible live range interference of HLO values assuming a dependency-based // ordering (HloDependencyOrdering). Copies are added conservatively. There -// likely are copies which are not strictly necessary, but there are removed -// later in the pass via CopyRemover. +// likely are copies which are not strictly necessary, but they are removed +// later in the pass via RemoveUnnecessaryCopies. // // // Elements (each ShapeIndex) in the loop state are considered independently. A @@ -310,9 +310,9 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } -// We add copies for all the indices of the true and false computation roots, -// in order to resolve interference. We later rely on the CopyRemover to drop -// the unnecessary ones. +// We add copies for all the indices of the true and false computation roots, in +// order to resolve interference. We later rely on RemoveUnnecessaryCopies to +// drop the unnecessary ones. Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) { VLOG(2) << "Adding copies for kConditional instruction " @@ -334,7 +334,8 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, // Conservatively adds copies before root instruction of entry computation and // each aliased parameter to resolve interference of aliased input and output -// buffer. We later rely on the CopyRemover to drop the unnecessary ones. +// buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary +// ones. Status AddCopiesForAliasedInputOutputs(HloModule* module) { HloComputation* entry = module->entry_computation(); HloInstruction* root = entry->root_instruction(); @@ -432,578 +433,527 @@ Status StripControlDependenciesFrom(HloInstruction* instruction) { return Status::OK(); } -// Class for removing unnecessary copies from the module. +// Class which tracks the HLO values within each HLO buffer in the module +// during copy removal. // -// kCopy instructions are added conservatively to guarantee no live range -// interference between HLO values. This class uses a more fine-grained analysis -// to remove some of these added copies which are not strictly necessary. +// The values are held in a linked list where there is one list for each +// buffer. Removing a copy instruction merges together the values in the +// source buffer of the copy to the destination buffer of the copy. This class +// tracks these value lists as copies are removed from the graph (and value +// lists are merged). +// +// The CopyRemover object is initialized to match the state of +// HloAliasAnalysis. However, as copies are removed this state diverges. The +// values-to-buffer mapping is maintained outside of HloAliasAnalysis because +// a fully updatable alias analysis is very slow. class CopyRemover { public: - CopyRemover(const HloAliasAnalysis& alias_analysis, - const HloOrdering& ordering, HloModule* module) - : module_(module), - alias_analysis_(alias_analysis), - buffer_value_tracker_(*module, alias_analysis, ordering) {} - - // Try to elide the given copy. The copy is elided if the instruction is not - // necessary to prevent live-range interference of HLO values. Returns true if - // copy was elided. + // The values held in a single HLO buffer are represented using a linked + // list. An element type in this list is ValueNode. // - // The copy instruction is not actually removed here. Instead it is left for - // dead in the graph. Later calls to DCE will remove the instruction. - StatusOr TryElideCopy(HloInstruction* copy) { - if (buffer_value_tracker_.TryElideCopy(copy)) { - TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy)); - TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0))); - return true; + // This linked list is hand-rolled to enable efficient splicing of lists + // using only references to list elements without knowing which lists are + // being spliced. std::list requires a reference to the list object to + // splice. + struct ValueNode { + explicit ValueNode(const HloValue* v) : value(v) {} + + const HloValue* value; + + // The uses are maintained outside of HloValue::uses() because + // HloValue::uses() is not updatable (a fully updatable dataflow analysis + // is slow). + std::vector uses; + + // next/prev elements in the linked list. The list is circularly linked so + // these values are never null for elements in the list. + ValueNode* prev = nullptr; + ValueNode* next = nullptr; + }; + + CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis, + const HloOrdering& ordering) + : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { + // Construct a list for each HLO buffer in the alias analysis. Maintain a + // map from HloValue to the respective list element representing that + // value. The map is used to construct the copy info map below. + absl::flat_hash_map value_to_node; + for (const HloBuffer& buffer : alias_analysis.buffers()) { + // Verify values contained in the buffer are strictly ordered. This + // should always be the case after adding copies to eliminate + // interference. Specifically, the addition of the control flow edges + // between copies added around aliased operations (kWhile) guarantees + // this strict order. + for (const HloValue* value_a : buffer.values()) { + if (value_a->shape().IsToken()) { + // Token values have no representation and cannot interfere. + continue; + } + for (const HloValue* value_b : buffer.values()) { + if (value_a != value_b) { + DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, + dataflow_) || + ordering_.LiveRangeStrictlyBefore(*value_b, *value_a, + dataflow_)) + << value_a->ToShortString() << " and " + << value_b->ToShortString() << " are not ordered"; + } + } + } + + std::vector values = buffer.values(); + absl::c_sort(values, [this](const HloValue* a, const HloValue* b) { + return ordering_.IsDefinedBefore(*a, *b); + }); + + // Create a list containing all of the values in the buffer. + AddValueList(values, &value_to_node); } - return false; + + // Create copy_map_ which contains the source and destination values + // of all copies. + CreateCopyMap(module, value_to_node); + + XLA_VLOG_LINES(3, ToString()); + TF_DCHECK_OK(Verify()); + } + + // Add a list containing the given values to CopyRemover. This + // represents the values contained in a single buffer. For each value in + // 'values' an entry is created in value_to_node which indicates the + // respective ValueNode representing that value. + void AddValueList( + absl::Span values, + absl::flat_hash_map* value_to_node) { + ValueNode* tail = nullptr; + ValueNode* head = nullptr; + for (const HloValue* value : values) { + auto new_node = new ValueNode(value); + (*value_to_node)[value] = new_node; + + // Copy the HLO values's uses into the ValueNode for the value. These + // uses in ValueNode are updated as copies are removed. + new_node->uses.reserve(value->uses().size()); + for (const HloUse& use : value->uses()) { + new_node->uses.push_back(&use); + } + + // Connect the new node into the linked list. + if (tail == nullptr) { + head = new_node; + } else { + tail->next = new_node; + new_node->prev = tail; + } + tail = new_node; + } + + // The linked list is circular so connect the head and tail. + tail->next = head; + head->prev = tail; + value_lists_.insert(head); + } + + // This method also fills in copy_map_ which indicates which nodes + // in the value lists corresponding to the source and destination values of + // kCopy instructions. value_to_node should map each HloValue to its + // respective ValueNode. + void CreateCopyMap( + const HloModule& module, + const absl::flat_hash_map& value_to_node) { + for (HloComputation* computation : module.computations()) { + for (HloInstruction* instruction : computation->instructions()) { + // Add copies with unambiguous source values to the map. Copies with + // ambiguous sources are not removable. + if (instruction->opcode() == HloOpcode::kCopy) { + const HloValueSet& src_value_set = + dataflow_.GetValueSet(instruction->operand(0)); + if (src_value_set.values().size() == 1) { + CopyNodes& copy_node = copy_map_[instruction]; + copy_node.dest = + value_to_node.at(&dataflow_.GetUniqueValueAt(instruction)); + copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue()); + } + } + } + } + } + + ~CopyRemover() { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + const ValueNode* tmp = p->next; + delete p; + p = tmp; + } while (p != head); + } + } + + // Verify invariants within the linked lists. + Status Verify() const { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + // Verify links between elements are consistent. + TF_RET_CHECK(p->prev->next == p); + TF_RET_CHECK(p->next->prev == p); + + const HloInstruction* def = p->value->defining_instruction(); + if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) { + TF_RET_CHECK(copy_map_.at(def).dest == p); + } + for (const HloUse* use : p->uses) { + if (use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, use->instruction)) { + TF_RET_CHECK(copy_map_.at(use->instruction).src == p); + } + } + + p = p->next; + } while (p != head); + } + return Status::OK(); + } + + // Try to elide the given copy. Elision of a copy is possible only if no + // live range interference is introduced by the copy's elimination. If + // elision is possible, then the internal state (value lists) are updated, + // and true is returned. Returns false otherwise. + bool TryElideCopy(const HloInstruction* copy) { + VLOG(2) << "Trying to remove " << copy->name(); + + if (!ContainsKey(copy_map_, copy)) { + VLOG(2) << copy->name() << " is not removable"; + return false; + } + if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) { + VLOG(2) << copy->name() << " is not removable (shape mismatch)"; + return false; + } + const CopyNodes& copy_node = copy_map_.at(copy); + ValueNode* src = copy_node.src; + ValueNode* dest = copy_node.dest; + DCHECK(src != nullptr); + DCHECK(dest != nullptr); + + auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) { + VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; + if (LiveRangeBefore(a, b)) { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is before " << b.value->ToShortString(); + return true; + } else { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is not before " << b.value->ToShortString(); + return false; + } + }; + + VLOG(3) << copy->name() << " copies value " << src->value->ToShortString(); + VLOG(3) << "Source buffer values: " << ValueListToString(src); + VLOG(3) << "Dest buffer values: " << ValueListToString(dest); + + // A kCopy instruction copies an HLO value from a source buffer and + // defines an HLO value in a destination buffer. Most generally, the + // source and destination buffers may each hold more than one value at + // different points in the computation so we define the following: + // + // Values in source buffer: {s_0, ..., s_n} + // Values in destination buffer: {d_0, ..., d_m} + // + // A kCopy instruction between these buffers copies a value s_x in the + // source buffer and defines a value d_y in the destination buffer. The + // elision of a copy merges the source and destination buffers together, + // so the list of values for the source and destination buffers are + // merged. + // + // We handle two different cases for copy elision: + // + // (1) the kCopy defines the first value in the destination buffer (d_0). + // + // (2) the kCopy copies the last value in the source buffer (s_n). + // + // For the remaining case where the kCopy copies a not-last value from the + // source buffer to a not-first value of the destination buffer, the kCopy + // instruction cannot be removed. This case is generated, for example, if + // the kCopy copies a while body parameter of the loop state at one tuple + // index to a different tuple index in the while body root. Removal of the + // copy necessarily results in live range interference of values in the + // loop state at the two different tuple indices. + // + // We can only perform copy elision if the resulting merged values have + // totally ordered live ranges; otherwise the merged buffer would have + // live range interference. + if (src->next == dest) { + // In the process of eliding copies, its possible for a copy to have the + // same source and destination buffer. In this case, the copy can be + // safely removed. + VLOG(2) << copy->name() << " source and destination buffers are same."; + } else if (IsHead(*dest)) { + // The copy copies an arbitrary value in the source buffer (call it s_x) + // and defines d_0, the first value in the destination buffer. After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: + // + // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n} + // + // Removing the copy eliminates d_0, and uses of d_0 become uses of + // s_x. In the above ordering, the live range of d_m must be ordered + // before the live range of s_{x+1} and the definition and all uses of + // s_x must be ordered before the definition of d_1. These conditions + // are checked below prior to elision. + // + // ** Technically it might be possible to have a non-interfering + // non-trivial interleaving of the values of the source and + // destination buffers in the resulting order. However, this case is + // slow and complicated to check and likely not worth it. So instead + // we simply check for the case where *all* values of the destination + // buffer (d_1 through d_m) are spliced into the point where the copy + // used to be. + VLOG(2) << copy->name() << " defines the first value in its buffer"; + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); + if (!is_live_range_before(*src, *next_dest)) { + return false; + } + } + ValueNode* next_src = Next(*src); + + if (next_src != nullptr) { + // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. + ValueNode* last_dest = dest->prev; + DCHECK(IsTail(*last_dest)); + if (!is_live_range_before(*last_dest, *next_src)) { + return false; + } + } + + // Splice in destination buffer values list right after 'src'. + SpliceAfter(dest, src); + } else if (IsTail(*src)) { + // The copy copies the last value in the source buffer, s_n, and defines + // an arbitrary value in the destination buffer, d_y. After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: + // + // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m} + // + // Removing the copy eliminates d_y, and uses of d_y become uses of + // s_n. To enforce the above order, the live range of d_{y-1} must be + // before the live range of s_0, and the live range of s_n must be + // before the live range of d_{y+1}. + // + // ** See comment above in the code handling Case (1). + VLOG(2) << copy->name() << " copies the last value (" + << src->value->ToShortString() << ") in its buffer"; + + ValueNode* prev_dest = Prev(*dest); + // nullptr condition handled above in the first 'if' case. + DCHECK(prev_dest != nullptr); + ValueNode* first_src = src->next; + DCHECK(IsHead(*first_src)); + if (!is_live_range_before(*prev_dest, *first_src)) { + // Live range of value d_{y-1} is not before s_0. + return false; + } + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + if (!is_live_range_before(*src, *next_dest)) { + // Live range of value s_n is not before d_{y+1}. + return false; + } + } + + // Splice source buffer values list right after 'prev_dest'. + SpliceAfter(first_src, prev_dest); + } else { + VLOG(2) << copy->name() + << " copies value in middle of source buffer to value in middle " + "of destination buffer"; + return false; + } + + RemoveCopyValue(dest); + + XLA_VLOG_LINES(4, ToString()); + TF_DCHECK_OK(Verify()); + + return true; + } + + // Delete the given ValueNode associated with a elided kCopy + // instruction. This should be called after splicing the value lists of the + // source and destination buffers together. + void RemoveCopyValue(ValueNode* copy_value_node) { + CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(), + HloOpcode::kCopy); + ValueNode* operand_node = copy_value_node->prev; + CHECK(operand_node != copy_value_node); + + VLOG(2) << "Removing copy " << operand_node->value->ToShortString() + << " => " << copy_value_node->value->ToShortString(); + + // Splice out the copy value node. + operand_node->next = copy_value_node->next; + copy_value_node->next->prev = operand_node; + + // Patch up uses. Remove use of copy from operand_node uses. + auto it = absl::c_find_if(operand_node->uses, [copy_value_node]( + const HloUse* use) { + return use->instruction == copy_value_node->value->defining_instruction(); + }); + CHECK(it != operand_node->uses.end()); + operand_node->uses.erase(it); + + // If the elided copy has any uses which are themselves kCopy instructions + // then patch up the copy info to reflect the that this kCopy instruction + // has a different operand (the operand of the elided copy). + for (const HloUse* copy_use : copy_value_node->uses) { + operand_node->uses.push_back(copy_use); + if (copy_use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, copy_use->instruction)) { + copy_map_.at(copy_use->instruction).src = operand_node; + } + } + + // Delete the copy info and the value node. + copy_map_.erase(copy_value_node->value->defining_instruction()); + delete copy_value_node; + } + + // Returns true if the live range of given value 'a' is before the live + // range of 'b'. + // + // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not + // updated as copies are removed. + bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { + if (a.uses.empty()) { + VLOG(2) << "Empty uses for " << *a.value; + return ordering_.IsDefinedBefore(*a.value, *b.value); + } + for (const HloUse* use : a.uses) { + VLOG(2) << "Checking use " << *use << " against " << *b.value; + if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { + VLOG(2) << "Use " << *use << " is NOT before " << *b.value; + return false; + } + VLOG(2) << "Use " << *use << " is before " << *b.value; + } + return true; + } + + // Returns whether 'node' is the last node in its list. + bool IsTail(const ValueNode& node) const { + return ContainsKey(value_lists_, node.next); + } + + // Returns whether 'node' is the first node in its list. + bool IsHead(const ValueNode& node) const { + return ContainsKey(value_lists_, &node); + } + + // Returns the next node in the list after 'node'. If 'node' is the + // tail, then nullptr is returned. + ValueNode* Next(const ValueNode& node) const { + if (IsTail(node)) { + return nullptr; + } else { + return node.next; + } + } + + // Returns the previous node in the list before 'node'. If 'node' + // is the head, then nullptr is returned. + ValueNode* Prev(const ValueNode& node) const { + if (IsHead(node)) { + return nullptr; + } else { + return node.prev; + } + } + + // Splices the entire linked list with 'head' as its head right after the + // node 'insert_after' in another linked list. + void SpliceAfter(ValueNode* head, ValueNode* insert_after) { + DCHECK(IsHead(*head)); + value_lists_.erase(head); + + ValueNode* tail = head->prev; + tail->next = insert_after->next; + insert_after->next->prev = tail; + + insert_after->next = head; + head->prev = insert_after; + } + + string ValueListToString(const ValueNode* element) { + const ValueNode* head = element; + while (!IsHead(*head)) { + head = Prev(*head); + } + std::vector values; + for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { + values.push_back(p->value); + } + return absl::StrCat("{", + absl::StrJoin(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); } string ToString() const { - string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n"); - StrAppend(&out, " Buffer values, in dependency order:\n"); - for (const HloBuffer& buffer : alias_analysis_.buffers()) { - StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); + string out = absl::StrCat("CopyRemover:\n"); + StrAppend(&out, " Def-use chains in each buffer:\n"); + for (const ValueNode* head : value_lists_) { + StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), + ":\n"); + const ValueNode* p = head; + do { + StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", + absl::StrJoin(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), + "\n"); + + p = p->next; + } while (p != head); + } + StrAppend(&out, " Potentially removable copies:\n"); + for (const auto& pair : copy_map_) { + const HloInstruction* copy = pair.first; + const CopyNodes& copy_info = pair.second; + + StrAppend(&out, " ", copy->name(), " : ", + copy_info.src->value->ToShortString(), " => ", + copy_info.dest->value->ToShortString(), "\n"); } return out; } private: - // Class which tracks the HLO values within each HLO buffer in the module - // during copy removal. - // - // The values are held in a linked list where there is one list for each - // buffer. Removing a copy instruction merges together the values in the - // source buffer of the copy to the destination buffer of the copy. This class - // tracks these value lists as copies are removed from the graph (and value - // lists are merged). - // - // The BufferValueTracker object is initialized to match the state of - // HloAliasAnalysis. However, as copies are removed this state diverges. The - // values-to-buffer mapping is maintained outside of HloAliasAnalysis because - // a fully updatable alias analysis is very slow. - class BufferValueTracker { - public: - // The values held in a single HLO buffer are represented using a linked - // list. An element type in this list is ValueNode. - // - // This linked list is hand-rolled to enable efficient splicing of lists - // using only references to list elements without knowing which lists are - // being spliced. std::list requires a reference to the list object to - // splice. - struct ValueNode { - explicit ValueNode(const HloValue* v) : value(v) {} + const HloDataflowAnalysis& dataflow_; + const HloOrdering& ordering_; - const HloValue* value; + // The heads of all the value lists. Each value list represents the HLO + // values contained in a particular HLO buffer. The values in the list are + // in dependency order. + absl::flat_hash_set value_lists_; - // The uses are maintained outside of HloValue::uses() because - // HloValue::uses() is not updatable (a fully updatable dataflow analysis - // is slow). - std::vector uses; - - // next/prev elements in the linked list. The list is circularly linked so - // these values are never null for elements in the list. - ValueNode* prev = nullptr; - ValueNode* next = nullptr; - }; - - BufferValueTracker(const HloModule& module, - const HloAliasAnalysis& alias_analysis, - const HloOrdering& ordering) - : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { - // Construct a list for each HLO buffer in the alias analysis. Maintain a - // map from HloValue to the respective list element representing that - // value. The map is used to construct the copy info map below. - absl::flat_hash_map value_to_node; - for (const HloBuffer& buffer : alias_analysis.buffers()) { - // Verify values contained in the buffer are strictly ordered. This - // should always be the case after adding copies to eliminate - // interference. Specifically, the addition of the control flow edges - // between copies added around aliased operations (kWhile) guarantees - // this strict order. - for (const HloValue* value_a : buffer.values()) { - if (value_a->shape().IsToken()) { - // Token values have no representation and cannot interfere. - continue; - } - for (const HloValue* value_b : buffer.values()) { - if (value_a != value_b) { - DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, - dataflow_) || - ordering_.LiveRangeStrictlyBefore(*value_b, *value_a, - dataflow_)) - << value_a->ToShortString() << " and " - << value_b->ToShortString() << " are not ordered"; - } - } - } - - std::vector values = buffer.values(); - absl::c_sort(values, [this](const HloValue* a, const HloValue* b) { - return ordering_.IsDefinedBefore(*a, *b); - }); - - // Create a list containing all of the values in the buffer. - AddValueList(values, &value_to_node); - } - - // Create copy_map_ which contains the source and destination values - // of all copies. - CreateCopyMap(module, value_to_node); - - XLA_VLOG_LINES(3, ToString()); - TF_DCHECK_OK(Verify()); - } - - // Add a list containing the given values to BufferValueTracker. This - // represents the values contained in a single buffer. For each value in - // 'values' an entry is created in value_to_node which indicates the - // respective ValueNode representing that value. - void AddValueList( - absl::Span values, - absl::flat_hash_map* value_to_node) { - ValueNode* tail = nullptr; - ValueNode* head = nullptr; - for (const HloValue* value : values) { - auto new_node = new ValueNode(value); - (*value_to_node)[value] = new_node; - - // Copy the HLO values's uses into the ValueNode for the value. These - // uses in ValueNode are updated as copies are removed. - new_node->uses.reserve(value->uses().size()); - for (const HloUse& use : value->uses()) { - new_node->uses.push_back(&use); - } - - // Connect the new node into the linked list. - if (tail == nullptr) { - head = new_node; - } else { - tail->next = new_node; - new_node->prev = tail; - } - tail = new_node; - } - - // The linked list is circular so connect the head and tail. - tail->next = head; - head->prev = tail; - value_lists_.insert(head); - } - - // This method also fills in copy_map_ which indicates which nodes - // in the value lists corresponding to the source and destination values of - // kCopy instructions. value_to_node should map each HloValue to its - // respective ValueNode. - void CreateCopyMap( - const HloModule& module, - const absl::flat_hash_map& value_to_node) { - for (HloComputation* computation : module.computations()) { - for (HloInstruction* instruction : computation->instructions()) { - // Add copies with unambiguous source values to the map. Copies with - // ambiguous sources are not removable. - if (instruction->opcode() == HloOpcode::kCopy) { - const HloValueSet& src_value_set = - dataflow_.GetValueSet(instruction->operand(0)); - if (src_value_set.values().size() == 1) { - CopyNodes& copy_node = copy_map_[instruction]; - copy_node.dest = - value_to_node.at(&dataflow_.GetUniqueValueAt(instruction)); - copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue()); - } - } - } - } - } - - ~BufferValueTracker() { - for (const ValueNode* head : value_lists_) { - const ValueNode* p = head; - do { - const ValueNode* tmp = p->next; - delete p; - p = tmp; - } while (p != head); - } - } - - // Verify invariants within the linked lists. - Status Verify() const { - for (const ValueNode* head : value_lists_) { - const ValueNode* p = head; - do { - // Verify links between elements are consistent. - TF_RET_CHECK(p->prev->next == p); - TF_RET_CHECK(p->next->prev == p); - - const HloInstruction* def = p->value->defining_instruction(); - if (def->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, def)) { - TF_RET_CHECK(copy_map_.at(def).dest == p); - } - for (const HloUse* use : p->uses) { - if (use->instruction->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, use->instruction)) { - TF_RET_CHECK(copy_map_.at(use->instruction).src == p); - } - } - - p = p->next; - } while (p != head); - } - return Status::OK(); - } - - // Try to elide the given copy. Elision of a copy is possible only if no - // live range interference is introduced by the copy's elimination. If - // elision is possible, then the internal state (value lists) are updated, - // and true is returned. Returns false otherwise. - bool TryElideCopy(const HloInstruction* copy) { - VLOG(2) << "Trying to remove " << copy->name(); - - if (!ContainsKey(copy_map_, copy)) { - VLOG(2) << copy->name() << " is not removable"; - return false; - } - if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) { - VLOG(2) << copy->name() << " is not removable (shape mismatch)"; - return false; - } - const CopyNodes& copy_node = copy_map_.at(copy); - ValueNode* src = copy_node.src; - ValueNode* dest = copy_node.dest; - DCHECK(src != nullptr); - DCHECK(dest != nullptr); - - auto is_live_range_before = [this](const ValueNode& a, - const ValueNode& b) { - VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; - if (LiveRangeBefore(a, b)) { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is before " << b.value->ToShortString(); - return true; - } else { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is not before " << b.value->ToShortString(); - return false; - } - }; - - VLOG(3) << copy->name() << " copies value " - << src->value->ToShortString(); - VLOG(3) << "Source buffer values: " << ValueListToString(src); - VLOG(3) << "Dest buffer values: " << ValueListToString(dest); - - // A kCopy instruction copies an HLO value from a source buffer and - // defines an HLO value in a destination buffer. Most generally, the - // source and destination buffers may each hold more than one value at - // different points in the computation so we define the following: - // - // Values in source buffer: {s_0, ..., s_n} - // Values in destination buffer: {d_0, ..., d_m} - // - // A kCopy instruction between these buffers copies a value s_x in the - // source buffer and defines a value d_y in the destination buffer. The - // elision of a copy merges the source and destination buffers together, - // so the list of values for the source and destination buffers are - // merged. - // - // We handle two different cases for copy elision: - // - // (1) the kCopy defines the first value in the destination buffer (d_0). - // - // (2) the kCopy copies the last value in the source buffer (s_n). - // - // For the remaining case where the kCopy copies a not-last value from the - // source buffer to a not-first value of the destination buffer, the kCopy - // instruction cannot be removed. This case is generated, for example, if - // the kCopy copies a while body parameter of the loop state at one tuple - // index to a different tuple index in the while body root. Removal of the - // copy necessarily results in live range interference of values in the - // loop state at the two different tuple indices. - // - // We can only perform copy elision if the resulting merged values have - // totally ordered live ranges; otherwise the merged buffer would have - // live range interference. - if (src->next == dest) { - // In the process of eliding copies, its possible for a copy to have the - // same source and destination buffer. In this case, the copy can be - // safely removed. - VLOG(2) << copy->name() << " source and destination buffers are same."; - } else if (IsHead(*dest)) { - // The copy copies an arbitrary value in the source buffer (call it s_x) - // and defines d_0, the first value in the destination buffer. After - // merging, the values in the combined buffer must be strictly ordered - // as follows** to elide the copy: - // - // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n} - // - // Removing the copy eliminates d_0, and uses of d_0 become uses of - // s_x. In the above ordering, the live range of d_m must be ordered - // before the live range of s_{x+1} and the definition and all uses of - // s_x must be ordered before the definition of d_1. These conditions - // are checked below prior to elision. - // - // ** Technically it might be possible to have a non-interfering - // non-trivial interleaving of the values of the source and - // destination buffers in the resulting order. However, this case is - // slow and complicated to check and likely not worth it. So instead - // we simply check for the case where *all* values of the destination - // buffer (d_1 through d_m) are spliced into the point where the copy - // used to be. - VLOG(2) << copy->name() << " defines the first value in its buffer"; - ValueNode* next_dest = Next(*dest); - if (next_dest != nullptr) { - // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); - if (!is_live_range_before(*src, *next_dest)) { - return false; - } - } - ValueNode* next_src = Next(*src); - - if (next_src != nullptr) { - // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. - ValueNode* last_dest = dest->prev; - DCHECK(IsTail(*last_dest)); - if (!is_live_range_before(*last_dest, *next_src)) { - return false; - } - } - - // Splice in destination buffer values list right after 'src'. - SpliceAfter(dest, src); - } else if (IsTail(*src)) { - // The copy copies the last value in the source buffer, s_n, and defines - // an arbitrary value in the destination buffer, d_y. After - // merging, the values in the combined buffer must be strictly ordered - // as follows** to elide the copy: - // - // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m} - // - // Removing the copy eliminates d_y, and uses of d_y become uses of - // s_n. To enforce the above order, the live range of d_{y-1} must be - // before the live range of s_0, and the live range of s_n must be - // before the live range of d_{y+1}. - // - // ** See comment above in the code handling Case (1). - VLOG(2) << copy->name() << " copies the last value (" - << src->value->ToShortString() << ") in its buffer"; - - ValueNode* prev_dest = Prev(*dest); - // nullptr condition handled above in the first 'if' case. - DCHECK(prev_dest != nullptr); - ValueNode* first_src = src->next; - DCHECK(IsHead(*first_src)); - if (!is_live_range_before(*prev_dest, *first_src)) { - // Live range of value d_{y-1} is not before s_0. - return false; - } - ValueNode* next_dest = Next(*dest); - if (next_dest != nullptr) { - if (!is_live_range_before(*src, *next_dest)) { - // Live range of value s_n is not before d_{y+1}. - return false; - } - } - - // Splice source buffer values list right after 'prev_dest'. - SpliceAfter(first_src, prev_dest); - } else { - VLOG(2) - << copy->name() - << " copies value in middle of source buffer to value in middle " - "of destination buffer"; - return false; - } - - RemoveCopyValue(dest); - - XLA_VLOG_LINES(4, ToString()); - TF_DCHECK_OK(Verify()); - - return true; - } - - // Delete the given ValueNode associated with a elided kCopy - // instruction. This should be called after splicing the value lists of the - // source and destination buffers together. - void RemoveCopyValue(ValueNode* copy_value_node) { - CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(), - HloOpcode::kCopy); - ValueNode* operand_node = copy_value_node->prev; - CHECK(operand_node != copy_value_node); - - VLOG(2) << "Removing copy " << operand_node->value->ToShortString() - << " => " << copy_value_node->value->ToShortString(); - - // Splice out the copy value node. - operand_node->next = copy_value_node->next; - copy_value_node->next->prev = operand_node; - - // Patch up uses. Remove use of copy from operand_node uses. - auto it = absl::c_find_if( - operand_node->uses, [copy_value_node](const HloUse* use) { - return use->instruction == - copy_value_node->value->defining_instruction(); - }); - CHECK(it != operand_node->uses.end()); - operand_node->uses.erase(it); - - // If the elided copy has any uses which are themselves kCopy instructions - // then patch up the copy info to reflect the that this kCopy instruction - // has a different operand (the operand of the elided copy). - for (const HloUse* copy_use : copy_value_node->uses) { - operand_node->uses.push_back(copy_use); - if (copy_use->instruction->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, copy_use->instruction)) { - copy_map_.at(copy_use->instruction).src = operand_node; - } - } - - // Delete the copy info and the value node. - copy_map_.erase(copy_value_node->value->defining_instruction()); - delete copy_value_node; - } - - // Returns true if the live range of given value 'a' is before the live - // range of 'b'. - // - // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not - // updated as copies are removed. - bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { - if (a.uses.empty()) { - VLOG(2) << "Empty uses for " << *a.value; - return ordering_.IsDefinedBefore(*a.value, *b.value); - } - for (const HloUse* use : a.uses) { - VLOG(2) << "Checking use " << *use << " against " << *b.value; - if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Use " << *use << " is NOT before " << *b.value; - return false; - } - VLOG(2) << "Use " << *use << " is before " << *b.value; - } - return true; - } - - // Returns whether 'node' is the last node in its list. - bool IsTail(const ValueNode& node) const { - return ContainsKey(value_lists_, node.next); - } - - // Returns whether 'node' is the first node in its list. - bool IsHead(const ValueNode& node) const { - return ContainsKey(value_lists_, &node); - } - - // Returns the next node in the list after 'node'. If 'node' is the - // tail, then nullptr is returned. - ValueNode* Next(const ValueNode& node) const { - if (IsTail(node)) { - return nullptr; - } else { - return node.next; - } - } - - // Returns the previous node in the list before 'node'. If 'node' - // is the head, then nullptr is returned. - ValueNode* Prev(const ValueNode& node) const { - if (IsHead(node)) { - return nullptr; - } else { - return node.prev; - } - } - - // Splices the entire linked list with 'head' as its head right after the - // node 'insert_after' in another linked list. - void SpliceAfter(ValueNode* head, ValueNode* insert_after) { - DCHECK(IsHead(*head)); - value_lists_.erase(head); - - ValueNode* tail = head->prev; - tail->next = insert_after->next; - insert_after->next->prev = tail; - - insert_after->next = head; - head->prev = insert_after; - } - - string ValueListToString(const ValueNode* element) { - const ValueNode* head = element; - while (!IsHead(*head)) { - head = Prev(*head); - } - std::vector values; - for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { - values.push_back(p->value); - } - return absl::StrCat("{", - absl::StrJoin(values, ", ", - [](string* s, const HloValue* value) { - StrAppend(s, value->ToShortString()); - }), - "}"); - } - - string ToString() const { - string out = absl::StrCat("BufferValueTracker:\n"); - StrAppend(&out, " Def-use chains in each buffer:\n"); - for (const ValueNode* head : value_lists_) { - StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), - ":\n"); - const ValueNode* p = head; - do { - StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", - absl::StrJoin(p->uses, "; ", - [](string* s, const HloUse* use) { - StrAppend(s, use->ToString()); - }), - "\n"); - - p = p->next; - } while (p != head); - } - StrAppend(&out, " Potentially removable copies:\n"); - for (const auto& pair : copy_map_) { - const HloInstruction* copy = pair.first; - const CopyNodes& copy_info = pair.second; - - StrAppend(&out, " ", copy->name(), " : ", - copy_info.src->value->ToShortString(), " => ", - copy_info.dest->value->ToShortString(), "\n"); - } - return out; - } - - private: - const HloDataflowAnalysis& dataflow_; - const HloOrdering& ordering_; - - // The heads of all the value lists. Each value list represents the HLO - // values contained in a particular HLO buffer. The values in the list are - // in dependency order. - absl::flat_hash_set value_lists_; - - // Copy removal requires fast access to the value list elements - // corresponding to the source and destination values of the kCopy - // instruction. This data structure holds pointers to these elements for - // each kCopy instruction in the graph. - struct CopyNodes { - // The source and destinations values of the kCopy instruction. - ValueNode* src = nullptr; - ValueNode* dest = nullptr; - }; - absl::flat_hash_map copy_map_; + // Copy removal requires fast access to the value list elements + // corresponding to the source and destination values of the kCopy + // instruction. This data structure holds pointers to these elements for + // each kCopy instruction in the graph. + struct CopyNodes { + // The source and destinations values of the kCopy instruction. + ValueNode* src = nullptr; + ValueNode* dest = nullptr; }; - - HloModule* module_; - const HloAliasAnalysis& alias_analysis_; - - // Object tracking the HLO values contained in each HLO buffer. - BufferValueTracker buffer_value_tracker_; + absl::flat_hash_map copy_map_; }; void MaybeDumpModule(const string& message, const HloModule& module) { @@ -1159,14 +1109,24 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); - CopyRemover copy_remover(*alias_analysis, ordering, module); - XLA_VLOG_LINES(3, copy_remover.ToString()); + + CopyRemover copy_remover(*module, *alias_analysis, ordering); + if (VLOG_IS_ON(3)) { + LOG(INFO) << "Removing unnecessary copies in " << module->name(); + LOG(INFO) << "Buffer values, in dependency order: "; + for (const HloBuffer& buffer : alias_analysis->buffers()) { + LOG(INFO) << " HloBuffer " << buffer.id(); + } + } std::unique_ptr call_graph = CallGraph::Build(module); for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + if (instruction->opcode() == HloOpcode::kCopy && + copy_remover.TryElideCopy(instruction)) { + TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction)); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); } } }