[XLA] Run copy elision pass in a fixed point

Copies might not be elided due to lifetime collisions with other copies which
are yet to be removed. Running copy elision in a fixed point loop lets us elide
those copies as well.

Fixes #35874

PiperOrigin-RevId: 295773148
Change-Id: I2d70efa775dcb42c21ceb0d5078838dec2d60f06
This commit is contained in:
George Karpenkov 2020-02-18 10:58:27 -08:00 committed by TensorFlower Gardener
parent bd395324d8
commit a66d4828f3
2 changed files with 114 additions and 22 deletions

View File

@ -1043,15 +1043,31 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
HloInstruction* root = computation->root_instruction();
// Mark nondistinct/ambiguous indices.
absl::flat_hash_set<const HloBuffer*> seen;
absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
ShapeUtil::ForEachSubshape(
root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
std::vector<const HloBuffer*> buffers_at_index =
alias_analysis->ComputeBuffersAt(root, index);
bool buffer_seen_before = false;
for (const HloBuffer* buffer : buffers_at_index) {
buffer_seen_before |= !seen.insert(buffer).second;
buffer_seen_before |= !seen.emplace(buffer, index).second;
}
if (buffer_seen_before && policy.copy_root_replicated_buffers &&
computation == module->entry_computation() &&
module->input_output_alias_config().OutputHasAlias(index) &&
buffers_at_index.size() == 1) {
absl::optional<HloInputOutputAliasConfig::Alias> alias =
module->input_output_alias_config().GetAliasedParameter(index);
CHECK(alias) << "Alias does not exist";
const ShapeIndex& other_index = seen[buffers_at_index[0]];
VLOG(2) << "Output indices " << index.ToString() << " and "
<< other_index.ToString() << " are both aliased to "
<< alias->parameter_number << " copying " << other_index;
add_index_to_copy(root, other_index);
return;
}
if (buffers_at_index.size() > 1 ||
(buffer_seen_before && policy.copy_root_replicated_buffers)) {
VLOG(2) << "Index " << index << " of computation "
@ -1097,6 +1113,18 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
return Status::OK();
}
static int64 GetNumExistingCopies(const HloModule* module) {
int64 num_existing_copies = 0;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy) {
++num_existing_copies;
}
}
}
return num_existing_copies;
}
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
@ -1112,13 +1140,24 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
}
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy &&
copy_remover.TryElideCopy(instruction)) {
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
TF_RETURN_IF_ERROR(
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
int64 num_existing_copies = GetNumExistingCopies(module);
bool changed = true;
int64 num_iterations = -1;
while (changed) {
CHECK_LE(++num_iterations, num_existing_copies);
changed = false;
VLOG(2) << "Running fixpoint iteration " << num_iterations
<< " of copy elision";
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy &&
copy_remover.TryElideCopy(instruction)) {
changed = true;
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
TF_RETURN_IF_ERROR(
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
}
}
}
}
@ -1156,17 +1195,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
"Call graph must be flattened before copy insertion.");
}
int64 num_existing_copies = 0;
if (VLOG_IS_ON(1)) {
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy) {
++num_existing_copies;
}
}
}
}
TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
// Simplify the tuple structures introduced by the deep copies. This should be
@ -1185,7 +1213,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
*module);
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
*module);
@ -1202,7 +1229,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
}
}
}
VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
VLOG(1) << "Num copies before copy-insertion: "
<< GetNumExistingCopies(module);
VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
}

View File

@ -2274,5 +2274,69 @@ ENTRY TestComputation {
op::While(op::Copy(op::Parameter())));
}
TEST_F(CopyInsertionTest, FixpointComputationRequired) {
const string& hlo_string = R"(
HloModule Module
fused_computation {
param0 = f32[3,3,96,1] parameter(0)
param1 = f32[] parameter(1)
broadcast = f32[3,3,96,1] broadcast(f32[] param1), dimensions={}
ROOT %add.0 = f32[3,3,96,1] add(f32[3,3,96,1] param0, f32[3,3,96,1] broadcast)
}
ENTRY entry_computation {
arg0 = f32[3,3,96,1] parameter(0)
arg1 = f32[] parameter(1)
fusion = f32[3,3,96,1] fusion(f32[3,3,96,1] arg0, f32[] arg1),
kind=kLoop, calls=fused_computation
negate = f32[] negate(f32[] arg1)
ROOT tuple = (f32[3,3,96,1], f32[3,3,96,1], f32[], f32[]) tuple(
f32[3,3,96,1] fusion,
f32[3,3,96,1] arg0,
f32[] negate,
f32[] arg1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
// Set up the aliasing manually which normally would be set by
// alias_passthrough_params pass.
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1},
/*param_number=*/0,
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{3},
/*param_number=*/1,
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
InsertCopies(module.get());
// There should be no copies inserted.
EXPECT_EQ(CountCopies(*module), 0);
}
TEST_F(CopyInsertionTest, NoAliasCheckViolation) {
const string& hlo_string = R"(
HloModule cluster
ENTRY Entry {
%arg = f32[8,28,28,1] parameter(0)
%bitcast.2 = f32[8,1,28,28] bitcast(f32[8,28,28,1] %arg)
ROOT %tuple.1 = (f32[8,1,28,28], f32[8,28,28,1]) tuple(f32[8,1,28,28] %bitcast.2, f32[8,28,28,1] %arg)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1},
/*param_number=*/0,
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 1);
}
} // namespace
} // namespace xla