Virtualize the AddCopiesOnConditional function on copy insertion.

PiperOrigin-RevId: 355189766
Change-Id: Ib2891e1bf22312b0a8735df1612de0f9a46b1ba4
This commit is contained in:
A. Unique TensorFlower 2021-02-02 09:54:57 -08:00 committed by TensorFlower Gardener
parent 22bf3df6dc
commit c0c74b2f93
3 changed files with 35 additions and 29 deletions

View File

@ -332,35 +332,6 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
return Status::OK();
}
// We add copies for all non-phi indices of the true and false computation
// roots, in order to resolve interference. We later rely on
// RemoveUnnecessaryCopies to drop the unnecessary ones.
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
HloInstruction* conditional) {
VLOG(2) << "Adding copies for kConditional instruction "
<< conditional->name();
ShapeTree<bool> indices_to_copy(conditional->shape());
TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
conditional, &indices_to_copy)) {
VLOG(2) << "No copies necessary for kWhile instruction "
<< conditional->name();
return Status::OK();
}
for (HloComputation* computation : conditional->branch_computations()) {
HloInstruction* root = computation->root_instruction();
std::vector<HloInstruction*> users = root->users();
TF_ASSIGN_OR_RETURN(
HloInstruction * deep_copy,
computation->DeepCopyInstruction(root, &indices_to_copy));
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
}
computation->set_root_instruction(deep_copy);
}
return Status::OK();
}
// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
// will remove the unnecessary copies.
Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
@ -1006,6 +977,36 @@ class CopyRemover {
} // namespace
// We add copies for all non-phi indices of the true and false computation
// roots, in order to resolve interference. We later rely on
// RemoveUnnecessaryCopies to drop the unnecessary ones.
Status CopyInsertion::AddCopiesForConditional(
const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
VLOG(2) << "Adding copies for kConditional instruction "
<< conditional->name();
ShapeTree<bool> indices_to_copy(conditional->shape());
TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
conditional, &indices_to_copy)) {
VLOG(2) << "No copies necessary for kWhile instruction "
<< conditional->name();
return Status::OK();
}
for (HloComputation* computation : conditional->branch_computations()) {
HloInstruction* root = computation->root_instruction();
std::vector<HloInstruction*> users = root->users();
TF_ASSIGN_OR_RETURN(
HloInstruction * deep_copy,
computation->DeepCopyInstruction(root, &indices_to_copy));
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
}
computation->set_root_instruction(deep_copy);
}
return Status::OK();
}
// Add kCopy instructions to the given module to guarantee there is no
// live-range interference. Generally interference can only occur around kWhile
// instructions which have update-in-place semantics.

View File

@ -83,6 +83,10 @@ class CopyInsertion : public HloModulePass {
virtual Status AddSpecialCaseCopies(const CallGraph& call_graph,
HloModule* module);
// Add copies for conditional instructions.
virtual Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
HloInstruction* conditional);
// Backend specific function that decides whether an instruction can share
// buffer with its operand.
HloDataflowAnalysis::CanShareBuffer can_share_buffer_;

View File

@ -2166,6 +2166,7 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
// already-bad compile errors even worse.
XLA_VARIADIC_OP_PATTERN(AfterAll);
XLA_VARIADIC_OP_PATTERN(Concatenate);
XLA_VARIADIC_OP_PATTERN(Conditional);
XLA_VARIADIC_OP_PATTERN(CustomCall);
XLA_VARIADIC_OP_PATTERN(DynamicSlice)
XLA_VARIADIC_OP_PATTERN(Fusion);