Virtualize the AddCopiesOnConditional function on copy insertion.
PiperOrigin-RevId: 355189766 Change-Id: Ib2891e1bf22312b0a8735df1612de0f9a46b1ba4
This commit is contained in:
parent
22bf3df6dc
commit
c0c74b2f93
@ -332,35 +332,6 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// We add copies for all non-phi indices of the true and false computation
|
||||
// roots, in order to resolve interference. We later rely on
|
||||
// RemoveUnnecessaryCopies to drop the unnecessary ones.
|
||||
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
|
||||
HloInstruction* conditional) {
|
||||
VLOG(2) << "Adding copies for kConditional instruction "
|
||||
<< conditional->name();
|
||||
ShapeTree<bool> indices_to_copy(conditional->shape());
|
||||
TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
|
||||
if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
|
||||
conditional, &indices_to_copy)) {
|
||||
VLOG(2) << "No copies necessary for kWhile instruction "
|
||||
<< conditional->name();
|
||||
return Status::OK();
|
||||
}
|
||||
for (HloComputation* computation : conditional->branch_computations()) {
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
std::vector<HloInstruction*> users = root->users();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstruction * deep_copy,
|
||||
computation->DeepCopyInstruction(root, &indices_to_copy));
|
||||
for (HloInstruction* user : users) {
|
||||
TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
|
||||
}
|
||||
computation->set_root_instruction(deep_copy);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
|
||||
// will remove the unnecessary copies.
|
||||
Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
|
||||
@ -1006,6 +977,36 @@ class CopyRemover {
|
||||
|
||||
} // namespace
|
||||
|
||||
// We add copies for all non-phi indices of the true and false computation
|
||||
// roots, in order to resolve interference. We later rely on
|
||||
// RemoveUnnecessaryCopies to drop the unnecessary ones.
|
||||
Status CopyInsertion::AddCopiesForConditional(
|
||||
const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
|
||||
VLOG(2) << "Adding copies for kConditional instruction "
|
||||
<< conditional->name();
|
||||
ShapeTree<bool> indices_to_copy(conditional->shape());
|
||||
TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
|
||||
if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
|
||||
conditional, &indices_to_copy)) {
|
||||
VLOG(2) << "No copies necessary for kWhile instruction "
|
||||
<< conditional->name();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
for (HloComputation* computation : conditional->branch_computations()) {
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
std::vector<HloInstruction*> users = root->users();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstruction * deep_copy,
|
||||
computation->DeepCopyInstruction(root, &indices_to_copy));
|
||||
for (HloInstruction* user : users) {
|
||||
TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
|
||||
}
|
||||
computation->set_root_instruction(deep_copy);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add kCopy instructions to the given module to guarantee there is no
|
||||
// live-range interference. Generally interference can only occur around kWhile
|
||||
// instructions which have update-in-place semantics.
|
||||
|
@ -83,6 +83,10 @@ class CopyInsertion : public HloModulePass {
|
||||
virtual Status AddSpecialCaseCopies(const CallGraph& call_graph,
|
||||
HloModule* module);
|
||||
|
||||
// Add copies for conditional instructions.
|
||||
virtual Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
|
||||
HloInstruction* conditional);
|
||||
|
||||
// Backend specific function that decides whether an instruction can share
|
||||
// buffer with its operand.
|
||||
HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
|
||||
|
@ -2166,6 +2166,7 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
|
||||
// already-bad compile errors even worse.
|
||||
XLA_VARIADIC_OP_PATTERN(AfterAll);
|
||||
XLA_VARIADIC_OP_PATTERN(Concatenate);
|
||||
XLA_VARIADIC_OP_PATTERN(Conditional);
|
||||
XLA_VARIADIC_OP_PATTERN(CustomCall);
|
||||
XLA_VARIADIC_OP_PATTERN(DynamicSlice)
|
||||
XLA_VARIADIC_OP_PATTERN(Fusion);
|
||||
|
Loading…
Reference in New Issue
Block a user