[XLA] Run copy elision pass in a fixed point
Copies might not be elided due to lifetime collisions with other copies which are yet to be removed. Running copy elision in a fixed point loop lets us elide those copies as well. Fixes #35874 PiperOrigin-RevId: 295773148 Change-Id: I2d70efa775dcb42c21ceb0d5078838dec2d60f06
This commit is contained in:
parent
bd395324d8
commit
a66d4828f3
@ -1043,15 +1043,31 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
|
|||||||
HloInstruction* root = computation->root_instruction();
|
HloInstruction* root = computation->root_instruction();
|
||||||
|
|
||||||
// Mark nondistinct/ambiguous indices.
|
// Mark nondistinct/ambiguous indices.
|
||||||
absl::flat_hash_set<const HloBuffer*> seen;
|
absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
|
||||||
ShapeUtil::ForEachSubshape(
|
ShapeUtil::ForEachSubshape(
|
||||||
root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
||||||
std::vector<const HloBuffer*> buffers_at_index =
|
std::vector<const HloBuffer*> buffers_at_index =
|
||||||
alias_analysis->ComputeBuffersAt(root, index);
|
alias_analysis->ComputeBuffersAt(root, index);
|
||||||
bool buffer_seen_before = false;
|
bool buffer_seen_before = false;
|
||||||
for (const HloBuffer* buffer : buffers_at_index) {
|
for (const HloBuffer* buffer : buffers_at_index) {
|
||||||
buffer_seen_before |= !seen.insert(buffer).second;
|
buffer_seen_before |= !seen.emplace(buffer, index).second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (buffer_seen_before && policy.copy_root_replicated_buffers &&
|
||||||
|
computation == module->entry_computation() &&
|
||||||
|
module->input_output_alias_config().OutputHasAlias(index) &&
|
||||||
|
buffers_at_index.size() == 1) {
|
||||||
|
absl::optional<HloInputOutputAliasConfig::Alias> alias =
|
||||||
|
module->input_output_alias_config().GetAliasedParameter(index);
|
||||||
|
CHECK(alias) << "Alias does not exist";
|
||||||
|
const ShapeIndex& other_index = seen[buffers_at_index[0]];
|
||||||
|
VLOG(2) << "Output indices " << index.ToString() << " and "
|
||||||
|
<< other_index.ToString() << " are both aliased to "
|
||||||
|
<< alias->parameter_number << " copying " << other_index;
|
||||||
|
add_index_to_copy(root, other_index);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (buffers_at_index.size() > 1 ||
|
if (buffers_at_index.size() > 1 ||
|
||||||
(buffer_seen_before && policy.copy_root_replicated_buffers)) {
|
(buffer_seen_before && policy.copy_root_replicated_buffers)) {
|
||||||
VLOG(2) << "Index " << index << " of computation "
|
VLOG(2) << "Index " << index << " of computation "
|
||||||
@ -1097,6 +1113,18 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int64 GetNumExistingCopies(const HloModule* module) {
|
||||||
|
int64 num_existing_copies = 0;
|
||||||
|
for (HloComputation* computation : module->computations()) {
|
||||||
|
for (HloInstruction* instruction : computation->instructions()) {
|
||||||
|
if (instruction->opcode() == HloOpcode::kCopy) {
|
||||||
|
++num_existing_copies;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return num_existing_copies;
|
||||||
|
}
|
||||||
|
|
||||||
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
||||||
HloModule* module) {
|
HloModule* module) {
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||||
@ -1112,16 +1140,27 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
|
|
||||||
|
int64 num_existing_copies = GetNumExistingCopies(module);
|
||||||
|
bool changed = true;
|
||||||
|
int64 num_iterations = -1;
|
||||||
|
while (changed) {
|
||||||
|
CHECK_LE(++num_iterations, num_existing_copies);
|
||||||
|
changed = false;
|
||||||
|
VLOG(2) << "Running fixpoint iteration " << num_iterations
|
||||||
|
<< " of copy elision";
|
||||||
for (HloComputation* computation : module->computations()) {
|
for (HloComputation* computation : module->computations()) {
|
||||||
for (HloInstruction* instruction : computation->instructions()) {
|
for (HloInstruction* instruction : computation->instructions()) {
|
||||||
if (instruction->opcode() == HloOpcode::kCopy &&
|
if (instruction->opcode() == HloOpcode::kCopy &&
|
||||||
copy_remover.TryElideCopy(instruction)) {
|
copy_remover.TryElideCopy(instruction)) {
|
||||||
|
changed = true;
|
||||||
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
|
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
|
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1156,17 +1195,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
|||||||
"Call graph must be flattened before copy insertion.");
|
"Call graph must be flattened before copy insertion.");
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 num_existing_copies = 0;
|
|
||||||
if (VLOG_IS_ON(1)) {
|
|
||||||
for (HloComputation* computation : module->computations()) {
|
|
||||||
for (HloInstruction* instruction : computation->instructions()) {
|
|
||||||
if (instruction->opcode() == HloOpcode::kCopy) {
|
|
||||||
++num_existing_copies;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
|
TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
|
||||||
|
|
||||||
// Simplify the tuple structures introduced by the deep copies. This should be
|
// Simplify the tuple structures introduced by the deep copies. This should be
|
||||||
@ -1185,7 +1213,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
|||||||
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
|
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
|
||||||
DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
|
DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
|
||||||
*module);
|
*module);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
|
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
|
||||||
DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
|
DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
|
||||||
*module);
|
*module);
|
||||||
@ -1202,7 +1229,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
|
VLOG(1) << "Num copies before copy-insertion: "
|
||||||
|
<< GetNumExistingCopies(module);
|
||||||
VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
|
VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2274,5 +2274,69 @@ ENTRY TestComputation {
|
|||||||
op::While(op::Copy(op::Parameter())));
|
op::While(op::Copy(op::Parameter())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CopyInsertionTest, FixpointComputationRequired) {
|
||||||
|
const string& hlo_string = R"(
|
||||||
|
HloModule Module
|
||||||
|
|
||||||
|
fused_computation {
|
||||||
|
param0 = f32[3,3,96,1] parameter(0)
|
||||||
|
param1 = f32[] parameter(1)
|
||||||
|
broadcast = f32[3,3,96,1] broadcast(f32[] param1), dimensions={}
|
||||||
|
ROOT %add.0 = f32[3,3,96,1] add(f32[3,3,96,1] param0, f32[3,3,96,1] broadcast)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry_computation {
|
||||||
|
arg0 = f32[3,3,96,1] parameter(0)
|
||||||
|
arg1 = f32[] parameter(1)
|
||||||
|
fusion = f32[3,3,96,1] fusion(f32[3,3,96,1] arg0, f32[] arg1),
|
||||||
|
kind=kLoop, calls=fused_computation
|
||||||
|
negate = f32[] negate(f32[] arg1)
|
||||||
|
ROOT tuple = (f32[3,3,96,1], f32[3,3,96,1], f32[], f32[]) tuple(
|
||||||
|
f32[3,3,96,1] fusion,
|
||||||
|
f32[3,3,96,1] arg0,
|
||||||
|
f32[] negate,
|
||||||
|
f32[] arg1)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
// Set up the aliasing manually which normally would be set by
|
||||||
|
// alias_passthrough_params pass.
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{1},
|
||||||
|
/*param_number=*/0,
|
||||||
|
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{3},
|
||||||
|
/*param_number=*/1,
|
||||||
|
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||||
|
|
||||||
|
InsertCopies(module.get());
|
||||||
|
|
||||||
|
// There should be no copies inserted.
|
||||||
|
EXPECT_EQ(CountCopies(*module), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CopyInsertionTest, NoAliasCheckViolation) {
|
||||||
|
const string& hlo_string = R"(
|
||||||
|
HloModule cluster
|
||||||
|
|
||||||
|
ENTRY Entry {
|
||||||
|
%arg = f32[8,28,28,1] parameter(0)
|
||||||
|
%bitcast.2 = f32[8,1,28,28] bitcast(f32[8,28,28,1] %arg)
|
||||||
|
ROOT %tuple.1 = (f32[8,1,28,28], f32[8,28,28,1]) tuple(f32[8,1,28,28] %bitcast.2, f32[8,28,28,1] %arg)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{1},
|
||||||
|
/*param_number=*/0,
|
||||||
|
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||||
|
InsertCopies(module.get());
|
||||||
|
EXPECT_EQ(CountCopies(*module), 1);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user