[XLA] Simplify CopyInsertion a bit.

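Almost all of the logic of the CopyRemover class was contained in an inner
class, BufferValueTracker.  This change moves the small amount of logic that
lived directly in the outer CopyRemover class (stripping control dependencies
and replacing all uses of an elided copy with its operand, plus the VLOG
output) into callers, i.e. CopyInsertion::RemoveUnnecessaryCopies, and
promotes BufferValueTracker to be the new CopyRemover class.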

PiperOrigin-RevId: 236404892
This commit is contained in:
Justin Lebar 2019-03-01 17:05:32 -08:00 committed by TensorFlower Gardener
parent c4bdb39bf8
commit 0955a2de88

View File

@ -193,8 +193,8 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
// Add kCopy instructions around the given kWhile instruction to eliminate any // Add kCopy instructions around the given kWhile instruction to eliminate any
// possible live range interference of HLO values assuming a dependency-based // possible live range interference of HLO values assuming a dependency-based
// ordering (HloDependencyOrdering). Copies are added conservatively. There // ordering (HloDependencyOrdering). Copies are added conservatively. There
// likely are copies which are not strictly necessary, but there are removed // likely are copies which are not strictly necessary, but they are removed
// later in the pass via CopyRemover. // later in the pass via RemoveUnnecessaryCopies.
// //
// //
// Elements (each ShapeIndex) in the loop state are considered independently. A // Elements (each ShapeIndex) in the loop state are considered independently. A
@ -310,9 +310,9 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
return Status::OK(); return Status::OK();
} }
// We add copies for all the indices of the true and false computation roots, // We add copies for all the indices of the true and false computation roots, in
// in order to resolve interference. We later rely on the CopyRemover to drop // order to resolve interference. We later rely on RemoveUnnecessaryCopies to
// the unnecessary ones. // drop the unnecessary ones.
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
HloInstruction* conditional) { HloInstruction* conditional) {
VLOG(2) << "Adding copies for kConditional instruction " VLOG(2) << "Adding copies for kConditional instruction "
@ -334,7 +334,8 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
// Conservatively adds copies before root instruction of entry computation and // Conservatively adds copies before root instruction of entry computation and
// each aliased parameter to resolve interference of aliased input and output // each aliased parameter to resolve interference of aliased input and output
// buffer. We later rely on the CopyRemover to drop the unnecessary ones. // buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
// ones.
Status AddCopiesForAliasedInputOutputs(HloModule* module) { Status AddCopiesForAliasedInputOutputs(HloModule* module) {
HloComputation* entry = module->entry_computation(); HloComputation* entry = module->entry_computation();
HloInstruction* root = entry->root_instruction(); HloInstruction* root = entry->root_instruction();
@ -432,58 +433,20 @@ Status StripControlDependenciesFrom(HloInstruction* instruction) {
return Status::OK(); return Status::OK();
} }
// Class for removing unnecessary copies from the module. // Class which tracks the HLO values within each HLO buffer in the module
// during copy removal.
// //
// kCopy instructions are added conservatively to guarantee no live range // The values are held in a linked list where there is one list for each
// interference between HLO values. This class uses a more fine-grained analysis // buffer. Removing a copy instruction merges together the values in the
// to remove some of these added copies which are not strictly necessary. // source buffer of the copy to the destination buffer of the copy. This class
// tracks these value lists as copies are removed from the graph (and value
// lists are merged).
//
// The CopyRemover object is initialized to match the state of
// HloAliasAnalysis. However, as copies are removed this state diverges. The
// values-to-buffer mapping is maintained outside of HloAliasAnalysis because
// a fully updatable alias analysis is very slow.
class CopyRemover { class CopyRemover {
public:
CopyRemover(const HloAliasAnalysis& alias_analysis,
const HloOrdering& ordering, HloModule* module)
: module_(module),
alias_analysis_(alias_analysis),
buffer_value_tracker_(*module, alias_analysis, ordering) {}
// Try to elide the given copy. The copy is elided if the instruction is not
// necessary to prevent live-range interference of HLO values. Returns true if
// copy was elided.
//
// The copy instruction is not actually removed here. Instead it is left for
// dead in the graph. Later calls to DCE will remove the instruction.
StatusOr<bool> TryElideCopy(HloInstruction* copy) {
if (buffer_value_tracker_.TryElideCopy(copy)) {
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy));
TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0)));
return true;
}
return false;
}
string ToString() const {
string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n");
StrAppend(&out, " Buffer values, in dependency order:\n");
for (const HloBuffer& buffer : alias_analysis_.buffers()) {
StrAppend(&out, " HloBuffer ", buffer.id(), ":\n");
}
return out;
}
private:
// Class which tracks the HLO values within each HLO buffer in the module
// during copy removal.
//
// The values are held in a linked list where there is one list for each
// buffer. Removing a copy instruction merges together the values in the
// source buffer of the copy to the destination buffer of the copy. This class
// tracks these value lists as copies are removed from the graph (and value
// lists are merged).
//
// The BufferValueTracker object is initialized to match the state of
// HloAliasAnalysis. However, as copies are removed this state diverges. The
// values-to-buffer mapping is maintained outside of HloAliasAnalysis because
// a fully updatable alias analysis is very slow.
class BufferValueTracker {
public: public:
// The values held in a single HLO buffer are represented using a linked // The values held in a single HLO buffer are represented using a linked
// list. An element type in this list is ValueNode. // list. An element type in this list is ValueNode.
@ -508,8 +471,7 @@ class CopyRemover {
ValueNode* next = nullptr; ValueNode* next = nullptr;
}; };
BufferValueTracker(const HloModule& module, CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
const HloAliasAnalysis& alias_analysis,
const HloOrdering& ordering) const HloOrdering& ordering)
: dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
// Construct a list for each HLO buffer in the alias analysis. Maintain a // Construct a list for each HLO buffer in the alias analysis. Maintain a
@ -556,7 +518,7 @@ class CopyRemover {
TF_DCHECK_OK(Verify()); TF_DCHECK_OK(Verify());
} }
// Add a list containing the given values to BufferValueTracker. This // Add a list containing the given values to CopyRemover. This
// represents the values contained in a single buffer. For each value in // represents the values contained in a single buffer. For each value in
// 'values' an entry is created in value_to_node which indicates the // 'values' an entry is created in value_to_node which indicates the
// respective ValueNode representing that value. // respective ValueNode representing that value.
@ -617,7 +579,7 @@ class CopyRemover {
} }
} }
~BufferValueTracker() { ~CopyRemover() {
for (const ValueNode* head : value_lists_) { for (const ValueNode* head : value_lists_) {
const ValueNode* p = head; const ValueNode* p = head;
do { do {
@ -638,8 +600,7 @@ class CopyRemover {
TF_RET_CHECK(p->next->prev == p); TF_RET_CHECK(p->next->prev == p);
const HloInstruction* def = p->value->defining_instruction(); const HloInstruction* def = p->value->defining_instruction();
if (def->opcode() == HloOpcode::kCopy && if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
ContainsKey(copy_map_, def)) {
TF_RET_CHECK(copy_map_.at(def).dest == p); TF_RET_CHECK(copy_map_.at(def).dest == p);
} }
for (const HloUse* use : p->uses) { for (const HloUse* use : p->uses) {
@ -676,8 +637,7 @@ class CopyRemover {
DCHECK(src != nullptr); DCHECK(src != nullptr);
DCHECK(dest != nullptr); DCHECK(dest != nullptr);
auto is_live_range_before = [this](const ValueNode& a, auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) {
const ValueNode& b) {
VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value;
if (LiveRangeBefore(a, b)) { if (LiveRangeBefore(a, b)) {
VLOG(2) << " Live range of " << a.value->ToShortString() VLOG(2) << " Live range of " << a.value->ToShortString()
@ -690,8 +650,7 @@ class CopyRemover {
} }
}; };
VLOG(3) << copy->name() << " copies value " VLOG(3) << copy->name() << " copies value " << src->value->ToShortString();
<< src->value->ToShortString();
VLOG(3) << "Source buffer values: " << ValueListToString(src); VLOG(3) << "Source buffer values: " << ValueListToString(src);
VLOG(3) << "Dest buffer values: " << ValueListToString(dest); VLOG(3) << "Dest buffer values: " << ValueListToString(dest);
@ -810,8 +769,7 @@ class CopyRemover {
// Splice source buffer values list right after 'prev_dest'. // Splice source buffer values list right after 'prev_dest'.
SpliceAfter(first_src, prev_dest); SpliceAfter(first_src, prev_dest);
} else { } else {
VLOG(2) VLOG(2) << copy->name()
<< copy->name()
<< " copies value in middle of source buffer to value in middle " << " copies value in middle of source buffer to value in middle "
"of destination buffer"; "of destination buffer";
return false; return false;
@ -842,10 +800,9 @@ class CopyRemover {
copy_value_node->next->prev = operand_node; copy_value_node->next->prev = operand_node;
// Patch up uses. Remove use of copy from operand_node uses. // Patch up uses. Remove use of copy from operand_node uses.
auto it = absl::c_find_if( auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
operand_node->uses, [copy_value_node](const HloUse* use) { const HloUse* use) {
return use->instruction == return use->instruction == copy_value_node->value->defining_instruction();
copy_value_node->value->defining_instruction();
}); });
CHECK(it != operand_node->uses.end()); CHECK(it != operand_node->uses.end());
operand_node->uses.erase(it); operand_node->uses.erase(it);
@ -949,7 +906,7 @@ class CopyRemover {
} }
string ToString() const { string ToString() const {
string out = absl::StrCat("BufferValueTracker:\n"); string out = absl::StrCat("CopyRemover:\n");
StrAppend(&out, " Def-use chains in each buffer:\n"); StrAppend(&out, " Def-use chains in each buffer:\n");
for (const ValueNode* head : value_lists_) { for (const ValueNode* head : value_lists_) {
StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), StrAppend(&out, " Buffer defined by ", head->value->ToShortString(),
@ -997,13 +954,6 @@ class CopyRemover {
ValueNode* dest = nullptr; ValueNode* dest = nullptr;
}; };
absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_; absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
};
HloModule* module_;
const HloAliasAnalysis& alias_analysis_;
// Object tracking the HLO values contained in each HLO buffer.
BufferValueTracker buffer_value_tracker_;
}; };
void MaybeDumpModule(const string& message, const HloModule& module) { void MaybeDumpModule(const string& message, const HloModule& module) {
@ -1159,14 +1109,24 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
CopyRemover copy_remover(*alias_analysis, ordering, module);
XLA_VLOG_LINES(3, copy_remover.ToString()); CopyRemover copy_remover(*module, *alias_analysis, ordering);
if (VLOG_IS_ON(3)) {
LOG(INFO) << "Removing unnecessary copies in " << module->name();
LOG(INFO) << "Buffer values, in dependency order: ";
for (const HloBuffer& buffer : alias_analysis->buffers()) {
LOG(INFO) << " HloBuffer " << buffer.id();
}
}
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
for (HloComputation* computation : module->computations()) { for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) { for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy) { if (instruction->opcode() == HloOpcode::kCopy &&
TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); copy_remover.TryElideCopy(instruction)) {
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
TF_RETURN_IF_ERROR(
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
} }
} }
} }