[XLA] Simplify CopyInsertion a bit.
Almost all of the logic of the CopyRemover class was contained in an inner class, BufferValueTracker. This change moves the remaining logic of the outer CopyRemover class into its callers, and promotes BufferValueTracker to be the new CopyRemover class.

PiperOrigin-RevId: 236404892
This commit is contained in:
parent
c4bdb39bf8
commit
0955a2de88
@ -193,8 +193,8 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
|
||||
// Add kCopy instructions around the given kWhile instruction to eliminate any
|
||||
// possible live range interference of HLO values assuming a dependency-based
|
||||
// ordering (HloDependencyOrdering). Copies are added conservatively. There
|
||||
// likely are copies which are not strictly necessary, but there are removed
|
||||
// later in the pass via CopyRemover.
|
||||
// likely are copies which are not strictly necessary, but they are removed
|
||||
// later in the pass via RemoveUnnecessaryCopies.
|
||||
//
|
||||
//
|
||||
// Elements (each ShapeIndex) in the loop state are considered independently. A
|
||||
@ -310,9 +310,9 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// We add copies for all the indices of the true and false computation roots,
|
||||
// in order to resolve interference. We later rely on the CopyRemover to drop
|
||||
// the unnecessary ones.
|
||||
// We add copies for all the indices of the true and false computation roots, in
|
||||
// order to resolve interference. We later rely on RemoveUnnecessaryCopies to
|
||||
// drop the unnecessary ones.
|
||||
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
|
||||
HloInstruction* conditional) {
|
||||
VLOG(2) << "Adding copies for kConditional instruction "
|
||||
@ -334,7 +334,8 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
|
||||
|
||||
// Conservatively adds copies before root instruction of entry computation and
|
||||
// each aliased parameter to resolve interference of aliased input and output
|
||||
// buffer. We later rely on the CopyRemover to drop the unnecessary ones.
|
||||
// buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
|
||||
// ones.
|
||||
Status AddCopiesForAliasedInputOutputs(HloModule* module) {
|
||||
HloComputation* entry = module->entry_computation();
|
||||
HloInstruction* root = entry->root_instruction();
|
||||
@ -432,44 +433,6 @@ Status StripControlDependenciesFrom(HloInstruction* instruction) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Class for removing unnecessary copies from the module.
|
||||
//
|
||||
// kCopy instructions are added conservatively to guarantee no live range
|
||||
// interference between HLO values. This class uses a more fine-grained analysis
|
||||
// to remove some of these added copies which are not strictly necessary.
|
||||
class CopyRemover {
|
||||
public:
|
||||
CopyRemover(const HloAliasAnalysis& alias_analysis,
|
||||
const HloOrdering& ordering, HloModule* module)
|
||||
: module_(module),
|
||||
alias_analysis_(alias_analysis),
|
||||
buffer_value_tracker_(*module, alias_analysis, ordering) {}
|
||||
|
||||
// Try to elide the given copy. The copy is elided if the instruction is not
|
||||
// necessary to prevent live-range interference of HLO values. Returns true if
|
||||
// copy was elided.
|
||||
//
|
||||
// The copy instruction is not actually removed here. Instead it is left for
|
||||
// dead in the graph. Later calls to DCE will remove the instruction.
|
||||
StatusOr<bool> TryElideCopy(HloInstruction* copy) {
|
||||
if (buffer_value_tracker_.TryElideCopy(copy)) {
|
||||
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy));
|
||||
TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0)));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
string ToString() const {
|
||||
string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n");
|
||||
StrAppend(&out, " Buffer values, in dependency order:\n");
|
||||
for (const HloBuffer& buffer : alias_analysis_.buffers()) {
|
||||
StrAppend(&out, " HloBuffer ", buffer.id(), ":\n");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
// Class which tracks the HLO values within each HLO buffer in the module
|
||||
// during copy removal.
|
||||
//
|
||||
@ -479,11 +442,11 @@ class CopyRemover {
|
||||
// tracks these value lists as copies are removed from the graph (and value
|
||||
// lists are merged).
|
||||
//
|
||||
// The BufferValueTracker object is initialized to match the state of
|
||||
// The CopyRemover object is initialized to match the state of
|
||||
// HloAliasAnalysis. However, as copies are removed this state diverges. The
|
||||
// values-to-buffer mapping is maintained outside of HloAliasAnalysis because
|
||||
// a fully updatable alias analysis is very slow.
|
||||
class BufferValueTracker {
|
||||
class CopyRemover {
|
||||
public:
|
||||
// The values held in a single HLO buffer are represented using a linked
|
||||
// list. An element type in this list is ValueNode.
|
||||
@ -508,8 +471,7 @@ class CopyRemover {
|
||||
ValueNode* next = nullptr;
|
||||
};
|
||||
|
||||
BufferValueTracker(const HloModule& module,
|
||||
const HloAliasAnalysis& alias_analysis,
|
||||
CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
|
||||
const HloOrdering& ordering)
|
||||
: dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
|
||||
// Construct a list for each HLO buffer in the alias analysis. Maintain a
|
||||
@ -556,7 +518,7 @@ class CopyRemover {
|
||||
TF_DCHECK_OK(Verify());
|
||||
}
|
||||
|
||||
// Add a list containing the given values to BufferValueTracker. This
|
||||
// Add a list containing the given values to CopyRemover. This
|
||||
// represents the values contained in a single buffer. For each value in
|
||||
// 'values' an entry is created in value_to_node which indicates the
|
||||
// respective ValueNode representing that value.
|
||||
@ -617,7 +579,7 @@ class CopyRemover {
|
||||
}
|
||||
}
|
||||
|
||||
~BufferValueTracker() {
|
||||
~CopyRemover() {
|
||||
for (const ValueNode* head : value_lists_) {
|
||||
const ValueNode* p = head;
|
||||
do {
|
||||
@ -638,8 +600,7 @@ class CopyRemover {
|
||||
TF_RET_CHECK(p->next->prev == p);
|
||||
|
||||
const HloInstruction* def = p->value->defining_instruction();
|
||||
if (def->opcode() == HloOpcode::kCopy &&
|
||||
ContainsKey(copy_map_, def)) {
|
||||
if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
|
||||
TF_RET_CHECK(copy_map_.at(def).dest == p);
|
||||
}
|
||||
for (const HloUse* use : p->uses) {
|
||||
@ -676,8 +637,7 @@ class CopyRemover {
|
||||
DCHECK(src != nullptr);
|
||||
DCHECK(dest != nullptr);
|
||||
|
||||
auto is_live_range_before = [this](const ValueNode& a,
|
||||
const ValueNode& b) {
|
||||
auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) {
|
||||
VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value;
|
||||
if (LiveRangeBefore(a, b)) {
|
||||
VLOG(2) << " Live range of " << a.value->ToShortString()
|
||||
@ -690,8 +650,7 @@ class CopyRemover {
|
||||
}
|
||||
};
|
||||
|
||||
VLOG(3) << copy->name() << " copies value "
|
||||
<< src->value->ToShortString();
|
||||
VLOG(3) << copy->name() << " copies value " << src->value->ToShortString();
|
||||
VLOG(3) << "Source buffer values: " << ValueListToString(src);
|
||||
VLOG(3) << "Dest buffer values: " << ValueListToString(dest);
|
||||
|
||||
@ -810,8 +769,7 @@ class CopyRemover {
|
||||
// Splice source buffer values list right after 'prev_dest'.
|
||||
SpliceAfter(first_src, prev_dest);
|
||||
} else {
|
||||
VLOG(2)
|
||||
<< copy->name()
|
||||
VLOG(2) << copy->name()
|
||||
<< " copies value in middle of source buffer to value in middle "
|
||||
"of destination buffer";
|
||||
return false;
|
||||
@ -842,10 +800,9 @@ class CopyRemover {
|
||||
copy_value_node->next->prev = operand_node;
|
||||
|
||||
// Patch up uses. Remove use of copy from operand_node uses.
|
||||
auto it = absl::c_find_if(
|
||||
operand_node->uses, [copy_value_node](const HloUse* use) {
|
||||
return use->instruction ==
|
||||
copy_value_node->value->defining_instruction();
|
||||
auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
|
||||
const HloUse* use) {
|
||||
return use->instruction == copy_value_node->value->defining_instruction();
|
||||
});
|
||||
CHECK(it != operand_node->uses.end());
|
||||
operand_node->uses.erase(it);
|
||||
@ -949,7 +906,7 @@ class CopyRemover {
|
||||
}
|
||||
|
||||
string ToString() const {
|
||||
string out = absl::StrCat("BufferValueTracker:\n");
|
||||
string out = absl::StrCat("CopyRemover:\n");
|
||||
StrAppend(&out, " Def-use chains in each buffer:\n");
|
||||
for (const ValueNode* head : value_lists_) {
|
||||
StrAppend(&out, " Buffer defined by ", head->value->ToShortString(),
|
||||
@ -999,13 +956,6 @@ class CopyRemover {
|
||||
absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
|
||||
};
|
||||
|
||||
HloModule* module_;
|
||||
const HloAliasAnalysis& alias_analysis_;
|
||||
|
||||
// Object tracking the HLO values contained in each HLO buffer.
|
||||
BufferValueTracker buffer_value_tracker_;
|
||||
};
|
||||
|
||||
void MaybeDumpModule(const string& message, const HloModule& module) {
|
||||
if (VLOG_IS_ON(3)) {
|
||||
VLOG(3) << message;
|
||||
@ -1159,14 +1109,24 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
|
||||
CopyRemover copy_remover(*alias_analysis, ordering, module);
|
||||
XLA_VLOG_LINES(3, copy_remover.ToString());
|
||||
|
||||
CopyRemover copy_remover(*module, *alias_analysis, ordering);
|
||||
if (VLOG_IS_ON(3)) {
|
||||
LOG(INFO) << "Removing unnecessary copies in " << module->name();
|
||||
LOG(INFO) << "Buffer values, in dependency order: ";
|
||||
for (const HloBuffer& buffer : alias_analysis->buffers()) {
|
||||
LOG(INFO) << " HloBuffer " << buffer.id();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kCopy) {
|
||||
TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
|
||||
if (instruction->opcode() == HloOpcode::kCopy &&
|
||||
copy_remover.TryElideCopy(instruction)) {
|
||||
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
|
||||
TF_RETURN_IF_ERROR(
|
||||
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user