[XLA] Run copy elision pass in a fixed point
Copies might not be elided due to lifetime collisions with other copies that are yet to be removed. Running copy elision in a fixed-point loop lets us elide those copies as well.

Fixes #35874
PiperOrigin-RevId: 295773148
Change-Id: I2d70efa775dcb42c21ceb0d5078838dec2d60f06
This commit is contained in:
parent
bd395324d8
commit
a66d4828f3
@ -1043,15 +1043,31 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
|
||||
// Mark nondistinct/ambiguous indices.
|
||||
absl::flat_hash_set<const HloBuffer*> seen;
|
||||
absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
|
||||
ShapeUtil::ForEachSubshape(
|
||||
root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
||||
std::vector<const HloBuffer*> buffers_at_index =
|
||||
alias_analysis->ComputeBuffersAt(root, index);
|
||||
bool buffer_seen_before = false;
|
||||
for (const HloBuffer* buffer : buffers_at_index) {
|
||||
buffer_seen_before |= !seen.insert(buffer).second;
|
||||
buffer_seen_before |= !seen.emplace(buffer, index).second;
|
||||
}
|
||||
|
||||
if (buffer_seen_before && policy.copy_root_replicated_buffers &&
|
||||
computation == module->entry_computation() &&
|
||||
module->input_output_alias_config().OutputHasAlias(index) &&
|
||||
buffers_at_index.size() == 1) {
|
||||
absl::optional<HloInputOutputAliasConfig::Alias> alias =
|
||||
module->input_output_alias_config().GetAliasedParameter(index);
|
||||
CHECK(alias) << "Alias does not exist";
|
||||
const ShapeIndex& other_index = seen[buffers_at_index[0]];
|
||||
VLOG(2) << "Output indices " << index.ToString() << " and "
|
||||
<< other_index.ToString() << " are both aliased to "
|
||||
<< alias->parameter_number << " copying " << other_index;
|
||||
add_index_to_copy(root, other_index);
|
||||
return;
|
||||
}
|
||||
|
||||
if (buffers_at_index.size() > 1 ||
|
||||
(buffer_seen_before && policy.copy_root_replicated_buffers)) {
|
||||
VLOG(2) << "Index " << index << " of computation "
|
||||
@ -1097,6 +1113,18 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns how many instructions with opcode kCopy are currently present in
// `module`, summed over every computation the module exposes through
// computations().
static int64 GetNumExistingCopies(const HloModule* module) {
  int64 copy_count = 0;
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* instr : comp->instructions()) {
      // Only copy instructions contribute to the tally.
      if (instr->opcode() != HloOpcode::kCopy) {
        continue;
      }
      ++copy_count;
    }
  }
  return copy_count;
}
|
||||
|
||||
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
||||
HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
@ -1112,13 +1140,24 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
||||
}
|
||||
|
||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kCopy &&
|
||||
copy_remover.TryElideCopy(instruction)) {
|
||||
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
|
||||
TF_RETURN_IF_ERROR(
|
||||
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
|
||||
|
||||
int64 num_existing_copies = GetNumExistingCopies(module);
|
||||
bool changed = true;
|
||||
int64 num_iterations = -1;
|
||||
while (changed) {
|
||||
CHECK_LE(++num_iterations, num_existing_copies);
|
||||
changed = false;
|
||||
VLOG(2) << "Running fixpoint iteration " << num_iterations
|
||||
<< " of copy elision";
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kCopy &&
|
||||
copy_remover.TryElideCopy(instruction)) {
|
||||
changed = true;
|
||||
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
|
||||
TF_RETURN_IF_ERROR(
|
||||
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1156,17 +1195,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
||||
"Call graph must be flattened before copy insertion.");
|
||||
}
|
||||
|
||||
int64 num_existing_copies = 0;
|
||||
if (VLOG_IS_ON(1)) {
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kCopy) {
|
||||
++num_existing_copies;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
|
||||
|
||||
// Simplify the tuple structures introduced by the deep copies. This should be
|
||||
@ -1185,7 +1213,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
||||
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
|
||||
DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
|
||||
*module);
|
||||
|
||||
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
|
||||
DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
|
||||
*module);
|
||||
@ -1202,7 +1229,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
||||
}
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
|
||||
VLOG(1) << "Num copies before copy-insertion: "
|
||||
<< GetNumExistingCopies(module);
|
||||
VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
|
||||
}
|
||||
|
||||
|
@ -2274,5 +2274,69 @@ ENTRY TestComputation {
|
||||
op::While(op::Copy(op::Parameter())));
|
||||
}
|
||||
|
||||
// Builds an entry computation whose root tuple returns a loop fusion of
// (arg0, arg1), arg0 itself, negate(arg1), and arg1 itself; aliases output
// {1} to parameter 0 and output {3} to parameter 1; then runs InsertCopies
// and expects the module to contain zero copies.
// NOTE(review): per the test name, reaching zero copies presumably requires
// copy elision to iterate to a fixed point -- confirm against the pass.
TEST_F(CopyInsertionTest, FixpointComputationRequired) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[3,3,96,1] parameter(0)
|
||||
param1 = f32[] parameter(1)
|
||||
broadcast = f32[3,3,96,1] broadcast(f32[] param1), dimensions={}
|
||||
ROOT %add.0 = f32[3,3,96,1] add(f32[3,3,96,1] param0, f32[3,3,96,1] broadcast)
|
||||
}
|
||||
|
||||
ENTRY entry_computation {
|
||||
arg0 = f32[3,3,96,1] parameter(0)
|
||||
arg1 = f32[] parameter(1)
|
||||
fusion = f32[3,3,96,1] fusion(f32[3,3,96,1] arg0, f32[] arg1),
|
||||
kind=kLoop, calls=fused_computation
|
||||
negate = f32[] negate(f32[] arg1)
|
||||
ROOT tuple = (f32[3,3,96,1], f32[3,3,96,1], f32[], f32[]) tuple(
|
||||
f32[3,3,96,1] fusion,
|
||||
f32[3,3,96,1] arg0,
|
||||
f32[] negate,
|
||||
f32[] arg1)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
// Set up the aliasing manually which normally would be set by
|
||||
// alias_passthrough_params pass.
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1},
|
||||
/*param_number=*/0,
|
||||
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{3},
|
||||
/*param_number=*/1,
|
||||
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
|
||||
InsertCopies(module.get());
|
||||
|
||||
// There should be no copies inserted.
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
}
|
||||
|
||||
// Builds an entry computation whose root tuple returns both a bitcast of the
// parameter and the parameter itself, aliases output {1} to parameter 0, runs
// InsertCopies, and expects exactly one copy in the resulting module.
// NOTE(review): per the test name, this presumably guards against a CHECK
// failure in the aliasing path -- confirm against the pass.
TEST_F(CopyInsertionTest, NoAliasCheckViolation) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule cluster
|
||||
|
||||
ENTRY Entry {
|
||||
%arg = f32[8,28,28,1] parameter(0)
|
||||
%bitcast.2 = f32[8,1,28,28] bitcast(f32[8,28,28,1] %arg)
|
||||
ROOT %tuple.1 = (f32[8,1,28,28], f32[8,28,28,1]) tuple(f32[8,1,28,28] %bitcast.2, f32[8,28,28,1] %arg)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1},
|
||||
/*param_number=*/0,
|
||||
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user