[XLA] Fix TryElideCopy in copy insertion to traverse all related sources/destinations when considering removing a copy. Also modify the constructor of CopyRemover to check the ordering of aliased live ranges only when using the default dependence-based HloOrdering.

PiperOrigin-RevId: 359548934
Change-Id: I6454cdfcc8d8091719bb8261d88dcb36cbb504cd
This commit is contained in:
A. Unique TensorFlower 2021-02-25 10:03:22 -08:00 committed by TensorFlower Gardener
parent 5ab395bcb6
commit bbc5359754
2 changed files with 75 additions and 57 deletions

View File

@ -485,12 +485,13 @@ class CopyRemover {
};
CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
const HloOrdering& ordering)
const HloOrdering& ordering, bool check_live_range_ordering)
: dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
// Construct a list for each HLO buffer in the alias analysis. Maintain a
// map from HloValue to the respective list element representing that
// value. The map is used to construct the copy info map below.
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
// Perform check only if the default dependence-based ordering is used.
for (const HloBuffer& buffer : alias_analysis.buffers()) {
// No copies should have been inserted within fused computations, so no
// need to remove them. HloOrdering isn't compatible with HloValues inside
@ -498,24 +499,26 @@ class CopyRemover {
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
continue;
}
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
// interference. Specifically, the addition of the control flow edges
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
if (value_a->shape().IsToken()) {
// Token values have no representation and cannot interfere.
continue;
}
for (const HloValue* value_b : buffer.values()) {
if (value_a != value_b) {
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
dataflow_) ||
ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
dataflow_))
<< value_a->ToShortString() << " and "
<< value_b->ToShortString() << " are not ordered";
if (check_live_range_ordering) {
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
// interference. Specifically, the addition of the control flow edges
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
if (value_a->shape().IsToken()) {
// Token values have no representation and cannot interfere.
continue;
}
for (const HloValue* value_b : buffer.values()) {
if (value_a != value_b) {
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
dataflow_) ||
ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
dataflow_))
<< value_a->ToString() << " and " << value_b->ToString()
<< " are not ordered";
}
}
}
}
@ -729,27 +732,31 @@ class CopyRemover {
VLOG(2) << copy->name() << " defines the first value in its buffer";
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
next_dest = Next(*next_dest)) {
// Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
if (!LiveRangeBefore(*src, *next_dest)) {
VLOG(2) << "Not removing the copy: live range of "
<< src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
// Live range of (s_x, s_{x-1},...) must be before 'next_dest' (d_1);
for (ValueNode* prev_src = src; prev_src != nullptr;
prev_src = Prev(*prev_src)) {
if (!LiveRangeBefore(*prev_src, *next_dest)) {
VLOG(2) << "Not removing the copy: live range of "
<< prev_src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
}
}
}
for (ValueNode* next_src = Next(*src); next_src != nullptr;
next_src = Next(*next_src)) {
// Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
ValueNode* last_dest = dest->prev;
DCHECK(IsTail(*last_dest));
if (!LiveRangeBefore(*last_dest, *next_src)) {
VLOG(2) << "Not removing the copy: live range of "
<< last_dest->value->ToShortString() << " is not before "
<< next_src->value->ToShortString();
return false;
for (ValueNode* last_dest = dest->prev; last_dest != nullptr;
last_dest = Prev(*dest)) {
if (!LiveRangeBefore(*last_dest, *next_src)) {
VLOG(2) << "Not removing the copy: live range of "
<< last_dest->value->ToShortString() << " is not before "
<< next_src->value->ToShortString();
return false;
}
}
}
VLOG(2) << "Splice dest after source.";
// Splice in destination buffer values list right after 'src'.
SpliceAfter(dest, src);
} else if (IsTail(*src)) {
@ -769,32 +776,36 @@ class CopyRemover {
VLOG(2) << copy->name() << " copies the last value ("
<< src->value->ToShortString() << ") in its buffer";
ValueNode* first_src = src->next;
DCHECK(IsHead(*first_src));
for (ValueNode* prev_dest = Prev(*dest);
// nullptr condition handled above in the first 'if' case.
prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
if (!LiveRangeBefore(*prev_dest, *first_src)) {
// Live range of value d_{y-1} is not before s_0.
VLOG(2) << "Not removing the copy: live range of "
<< prev_dest->value->ToShortString() << " is not before "
<< first_src->value->ToShortString();
return false;
for (ValueNode* next_src = src->next; next_src != nullptr;
next_src = Next(*next_src)) {
for (ValueNode* prev_dest = Prev(*dest);
// nullptr condition handled above in the first 'if' case.
prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
if (!LiveRangeBefore(*prev_dest, *next_src)) {
// Live range of value d_{y-1} is not before s_0.
VLOG(2) << "Not removing the copy: live range of "
<< prev_dest->value->ToShortString() << " is not before "
<< next_src->value->ToShortString();
return false;
}
}
}
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
next_dest = Next(*next_dest)) {
if (!LiveRangeBefore(*src, *next_dest)) {
// Live range of value s_n is not before d_{y+1}.
VLOG(2) << "Not removing the copy: live range of "
<< src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
for (ValueNode* prev_src = src; prev_src != nullptr;
prev_src = Prev(*prev_src)) {
if (!LiveRangeBefore(*prev_src, *next_dest)) {
// Live range of value s_n is not before d_{y+1}.
VLOG(2) << "Not removing the copy: live range of "
<< prev_src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
}
}
}
VLOG(2) << "Splice src after prev of dest.";
// Splice source buffer values list right after 'prev_dest'.
SpliceAfter(first_src, Prev(*dest));
SpliceAfter(src->next, Prev(*dest));
} else {
VLOG(2) << copy->name()
<< " copies value in middle of source buffer to value in middle "
@ -1175,12 +1186,14 @@ static int64 GetNumExistingCopies(const HloModule* module) {
}
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module) {
HloModule* module,
bool check_live_range_ordering) {
XLA_VLOG_LINES(4, module->ToString());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, can_share_buffer_));
CopyRemover copy_remover(*module, *alias_analysis, ordering);
CopyRemover copy_remover(*module, *alias_analysis, ordering,
check_live_range_ordering);
if (VLOG_IS_ON(3)) {
LOG(INFO) << "Removing unnecessary copies in " << module->name();
LOG(INFO) << "Buffer values, in dependency order: ";
@ -1200,7 +1213,9 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
VLOG(2) << "Running fixpoint iteration " << num_iterations
<< " of copy elision";
for (HloComputation* computation : module->computations()) {
VLOG(2) << "computation:" << computation->name() << "\n";
for (HloInstruction* instruction : computation->instructions()) {
VLOG(2) << instruction->ToString() << "\n";
if (instruction->opcode() == HloOpcode::kCopy &&
copy_remover.TryElideCopy(instruction)) {
changed = true;
@ -1260,7 +1275,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
name(), "after adding copies to resolve interference", *module);
TF_RETURN_IF_ERROR(
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module,
/*check_live_range_ordering=*/true));
DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
*module);
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));

View File

@ -63,8 +63,10 @@ class CopyInsertion : public HloModulePass {
// Try to remove as many copies from the module as possible without
// introducing live range interference. Only copy instructions that are
// eligible for copy elision are considered for removal.
Status RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module);
// If check_live_range_ordering is true, check that live ranges are ordered
// in all the existing aliased buffers.
Status RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module,
bool check_live_range_ordering = false);
// Add copies to address special constraints on the roots of computations not
// related to live range interference: