[XLA] Fix TryElideCopy in copy insertion to traverse all related sources/destinations when considering removing a copy. Also modify the constructor of CopyRemover to check the ordering of aliased live ranges only when using the default dependence-based HloOrdering.
PiperOrigin-RevId: 359548934 Change-Id: I6454cdfcc8d8091719bb8261d88dcb36cbb504cd
This commit is contained in:
parent
5ab395bcb6
commit
bbc5359754
@ -485,12 +485,13 @@ class CopyRemover {
|
||||
};
|
||||
|
||||
CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
|
||||
const HloOrdering& ordering)
|
||||
const HloOrdering& ordering, bool check_live_range_ordering)
|
||||
: dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
|
||||
// Construct a list for each HLO buffer in the alias analysis. Maintain a
|
||||
// map from HloValue to the respective list element representing that
|
||||
// value. The map is used to construct the copy info map below.
|
||||
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
|
||||
// Perform check only if the default dependence-based ordering is used.
|
||||
for (const HloBuffer& buffer : alias_analysis.buffers()) {
|
||||
// No copies should have been inserted within fused computations, so no
|
||||
// need to remove them. HloOrdering isn't compatible with HloValues inside
|
||||
@ -498,24 +499,26 @@ class CopyRemover {
|
||||
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
|
||||
continue;
|
||||
}
|
||||
// Verify values contained in the buffer are strictly ordered. This
|
||||
// should always be the case after adding copies to eliminate
|
||||
// interference. Specifically, the addition of the control flow edges
|
||||
// between copies added around aliased operations (kWhile) guarantees
|
||||
// this strict order.
|
||||
for (const HloValue* value_a : buffer.values()) {
|
||||
if (value_a->shape().IsToken()) {
|
||||
// Token values have no representation and cannot interfere.
|
||||
continue;
|
||||
}
|
||||
for (const HloValue* value_b : buffer.values()) {
|
||||
if (value_a != value_b) {
|
||||
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
|
||||
dataflow_) ||
|
||||
ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
|
||||
dataflow_))
|
||||
<< value_a->ToShortString() << " and "
|
||||
<< value_b->ToShortString() << " are not ordered";
|
||||
if (check_live_range_ordering) {
|
||||
// Verify values contained in the buffer are strictly ordered. This
|
||||
// should always be the case after adding copies to eliminate
|
||||
// interference. Specifically, the addition of the control flow edges
|
||||
// between copies added around aliased operations (kWhile) guarantees
|
||||
// this strict order.
|
||||
for (const HloValue* value_a : buffer.values()) {
|
||||
if (value_a->shape().IsToken()) {
|
||||
// Token values have no representation and cannot interfere.
|
||||
continue;
|
||||
}
|
||||
for (const HloValue* value_b : buffer.values()) {
|
||||
if (value_a != value_b) {
|
||||
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
|
||||
dataflow_) ||
|
||||
ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
|
||||
dataflow_))
|
||||
<< value_a->ToString() << " and " << value_b->ToString()
|
||||
<< " are not ordered";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -729,27 +732,31 @@ class CopyRemover {
|
||||
VLOG(2) << copy->name() << " defines the first value in its buffer";
|
||||
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
|
||||
next_dest = Next(*next_dest)) {
|
||||
// Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
|
||||
if (!LiveRangeBefore(*src, *next_dest)) {
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< src->value->ToShortString() << " is not before "
|
||||
<< next_dest->value->ToShortString();
|
||||
return false;
|
||||
// Live range of (s_x, s_{x-1},...) must be before 'next_dest' (d_1);
|
||||
for (ValueNode* prev_src = src; prev_src != nullptr;
|
||||
prev_src = Prev(*prev_src)) {
|
||||
if (!LiveRangeBefore(*prev_src, *next_dest)) {
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< prev_src->value->ToShortString() << " is not before "
|
||||
<< next_dest->value->ToShortString();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (ValueNode* next_src = Next(*src); next_src != nullptr;
|
||||
next_src = Next(*next_src)) {
|
||||
// Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
|
||||
ValueNode* last_dest = dest->prev;
|
||||
DCHECK(IsTail(*last_dest));
|
||||
if (!LiveRangeBefore(*last_dest, *next_src)) {
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< last_dest->value->ToShortString() << " is not before "
|
||||
<< next_src->value->ToShortString();
|
||||
return false;
|
||||
for (ValueNode* last_dest = dest->prev; last_dest != nullptr;
|
||||
last_dest = Prev(*dest)) {
|
||||
if (!LiveRangeBefore(*last_dest, *next_src)) {
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< last_dest->value->ToShortString() << " is not before "
|
||||
<< next_src->value->ToShortString();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(2) << "Splice dest after source.";
|
||||
// Splice in destination buffer values list right after 'src'.
|
||||
SpliceAfter(dest, src);
|
||||
} else if (IsTail(*src)) {
|
||||
@ -769,32 +776,36 @@ class CopyRemover {
|
||||
VLOG(2) << copy->name() << " copies the last value ("
|
||||
<< src->value->ToShortString() << ") in its buffer";
|
||||
|
||||
ValueNode* first_src = src->next;
|
||||
DCHECK(IsHead(*first_src));
|
||||
for (ValueNode* prev_dest = Prev(*dest);
|
||||
// nullptr condition handled above in the first 'if' case.
|
||||
prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
|
||||
if (!LiveRangeBefore(*prev_dest, *first_src)) {
|
||||
// Live range of value d_{y-1} is not before s_0.
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< prev_dest->value->ToShortString() << " is not before "
|
||||
<< first_src->value->ToShortString();
|
||||
return false;
|
||||
for (ValueNode* next_src = src->next; next_src != nullptr;
|
||||
next_src = Next(*next_src)) {
|
||||
for (ValueNode* prev_dest = Prev(*dest);
|
||||
// nullptr condition handled above in the first 'if' case.
|
||||
prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
|
||||
if (!LiveRangeBefore(*prev_dest, *next_src)) {
|
||||
// Live range of value d_{y-1} is not before s_0.
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< prev_dest->value->ToShortString() << " is not before "
|
||||
<< next_src->value->ToShortString();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
|
||||
next_dest = Next(*next_dest)) {
|
||||
if (!LiveRangeBefore(*src, *next_dest)) {
|
||||
// Live range of value s_n is not before d_{y+1}.
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< src->value->ToShortString() << " is not before "
|
||||
<< next_dest->value->ToShortString();
|
||||
return false;
|
||||
for (ValueNode* prev_src = src; prev_src != nullptr;
|
||||
prev_src = Prev(*prev_src)) {
|
||||
if (!LiveRangeBefore(*prev_src, *next_dest)) {
|
||||
// Live range of value s_n is not before d_{y+1}.
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< prev_src->value->ToShortString() << " is not before "
|
||||
<< next_dest->value->ToShortString();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(2) << "Splice src after prev of dest.";
|
||||
// Splice source buffer values list right after 'prev_dest'.
|
||||
SpliceAfter(first_src, Prev(*dest));
|
||||
SpliceAfter(src->next, Prev(*dest));
|
||||
} else {
|
||||
VLOG(2) << copy->name()
|
||||
<< " copies value in middle of source buffer to value in middle "
|
||||
@ -1175,12 +1186,14 @@ static int64 GetNumExistingCopies(const HloModule* module) {
|
||||
}
|
||||
|
||||
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
||||
HloModule* module) {
|
||||
HloModule* module,
|
||||
bool check_live_range_ordering) {
|
||||
XLA_VLOG_LINES(4, module->ToString());
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, can_share_buffer_));
|
||||
|
||||
CopyRemover copy_remover(*module, *alias_analysis, ordering);
|
||||
CopyRemover copy_remover(*module, *alias_analysis, ordering,
|
||||
check_live_range_ordering);
|
||||
if (VLOG_IS_ON(3)) {
|
||||
LOG(INFO) << "Removing unnecessary copies in " << module->name();
|
||||
LOG(INFO) << "Buffer values, in dependency order: ";
|
||||
@ -1200,7 +1213,9 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
||||
VLOG(2) << "Running fixpoint iteration " << num_iterations
|
||||
<< " of copy elision";
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
VLOG(2) << "computation:" << computation->name() << "\n";
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
VLOG(2) << instruction->ToString() << "\n";
|
||||
if (instruction->opcode() == HloOpcode::kCopy &&
|
||||
copy_remover.TryElideCopy(instruction)) {
|
||||
changed = true;
|
||||
@ -1260,7 +1275,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
||||
name(), "after adding copies to resolve interference", *module);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
|
||||
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module,
|
||||
/*check_live_range_ordering=*/true));
|
||||
DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
|
||||
*module);
|
||||
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
|
||||
|
@ -63,8 +63,10 @@ class CopyInsertion : public HloModulePass {
|
||||
// Try to remove as many copies from the module as possible without
|
||||
// introducing live range interference. Only copy instructions that are
|
||||
// eligible for copy elision are considered for removal.
|
||||
Status RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
||||
HloModule* module);
|
||||
// If check_live_range_ordering is true, check that live ranges are ordered
|
||||
// in all the existing aliased buffers.
|
||||
Status RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module,
|
||||
bool check_live_range_ordering = false);
|
||||
|
||||
// Add copies to address special constraints on the roots of computations not
|
||||
// related to live range interference:
|
||||
|
Loading…
Reference in New Issue
Block a user