Merge pull request #31989 from georgepaw:hlo_dce
PiperOrigin-RevId: 266199694
This commit is contained in:
commit
f0ebc5b745
@ -35,33 +35,48 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
|
||||
bool changed = false;
|
||||
VLOG(3) << "Before dce:";
|
||||
XLA_VLOG_LINES(3, computation->ToString());
|
||||
// Remove any dead roots and their dead transitive operands. Collect them
|
||||
// into a separate list first to avoid problems with iterating through the
|
||||
// computation's instruction while simultaneously removing instructions.
|
||||
std::vector<HloInstruction*> dead_roots;
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
if (instruction != computation->root_instruction() &&
|
||||
instruction->user_count() == 0 &&
|
||||
computation->IsSafelyRemovable(instruction) &&
|
||||
!instruction->HasSideEffect()) {
|
||||
dead_roots.push_back(instruction);
|
||||
}
|
||||
}
|
||||
|
||||
for (HloInstruction* dead_root : dead_roots) {
|
||||
VLOG(1) << "Removing dead root " << dead_root->ToString()
|
||||
<< " and it's unused operands";
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation->RemoveInstructionAndUnusedOperands(dead_root));
|
||||
changed = true;
|
||||
}
|
||||
if (changed) {
|
||||
VLOG(3) << "After dce:";
|
||||
XLA_VLOG_LINES(3, computation->ToString());
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
StatusOr<bool> HloDCE::Run(HloModule* module) {
|
||||
bool changed = false;
|
||||
|
||||
VLOG(2) << "Before dce:";
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
|
||||
// Run DCE on each computation.
|
||||
for (auto* computation : module->MakeComputationPostOrder()) {
|
||||
// Remove any dead roots and their dead transitive operands. Collect them
|
||||
// into a separate list first to avoid problems with iterating through the
|
||||
// computation's instruction while simultaneously removing instructions.
|
||||
std::vector<HloInstruction*> dead_roots;
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
if (instruction != computation->root_instruction() &&
|
||||
instruction->user_count() == 0 &&
|
||||
computation->IsSafelyRemovable(instruction) &&
|
||||
!instruction->HasSideEffect()) {
|
||||
dead_roots.push_back(instruction);
|
||||
}
|
||||
}
|
||||
|
||||
for (HloInstruction* dead_root : dead_roots) {
|
||||
VLOG(1) << "Removing dead root " << dead_root->ToString()
|
||||
<< " and it's unused operands";
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation->RemoveInstructionAndUnusedOperands(dead_root));
|
||||
changed = true;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(bool changed_for_computation,
|
||||
RunOnComputation(computation));
|
||||
changed |= changed_for_computation;
|
||||
}
|
||||
|
||||
// Now DCE HloComputations. First, collect the computations that are
|
||||
|
@ -38,6 +38,9 @@ class HloDCE : public HloModulePass {
|
||||
~HloDCE() override {}
|
||||
absl::string_view name() const override { return "dce"; }
|
||||
|
||||
// Run DCE on a computation.
|
||||
static StatusOr<bool> RunOnComputation(HloComputation* computation);
|
||||
|
||||
// Run the pass on the given module. Returns whether the module was changed
|
||||
// (instructions were removed).
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
Loading…
x
Reference in New Issue
Block a user