Merge pull request #31989 from georgepaw:hlo_dce

PiperOrigin-RevId: 266199694
This commit is contained in:
TensorFlower Gardener 2019-08-29 13:13:10 -07:00
commit f0ebc5b745
2 changed files with 38 additions and 20 deletions

View File

@ -35,33 +35,48 @@ limitations under the License.
namespace xla {
StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
bool changed = false;
VLOG(3) << "Before dce:";
XLA_VLOG_LINES(3, computation->ToString());
// Remove any dead roots and their dead transitive operands. Collect them
// into a separate list first to avoid problems with iterating through the
// computation's instruction while simultaneously removing instructions.
std::vector<HloInstruction*> dead_roots;
for (auto* instruction : computation->instructions()) {
if (instruction != computation->root_instruction() &&
instruction->user_count() == 0 &&
computation->IsSafelyRemovable(instruction) &&
!instruction->HasSideEffect()) {
dead_roots.push_back(instruction);
}
}
for (HloInstruction* dead_root : dead_roots) {
VLOG(1) << "Removing dead root " << dead_root->ToString()
<< " and it's unused operands";
TF_RETURN_IF_ERROR(
computation->RemoveInstructionAndUnusedOperands(dead_root));
changed = true;
}
if (changed) {
VLOG(3) << "After dce:";
XLA_VLOG_LINES(3, computation->ToString());
}
return changed;
}
StatusOr<bool> HloDCE::Run(HloModule* module) {
bool changed = false;
VLOG(2) << "Before dce:";
XLA_VLOG_LINES(2, module->ToString());
// Run DCE on each computation.
for (auto* computation : module->MakeComputationPostOrder()) {
// Remove any dead roots and their dead transitive operands. Collect them
// into a separate list first to avoid problems with iterating through the
// computation's instruction while simultaneously removing instructions.
std::vector<HloInstruction*> dead_roots;
for (auto* instruction : computation->instructions()) {
if (instruction != computation->root_instruction() &&
instruction->user_count() == 0 &&
computation->IsSafelyRemovable(instruction) &&
!instruction->HasSideEffect()) {
dead_roots.push_back(instruction);
}
}
for (HloInstruction* dead_root : dead_roots) {
VLOG(1) << "Removing dead root " << dead_root->ToString()
<< " and it's unused operands";
TF_RETURN_IF_ERROR(
computation->RemoveInstructionAndUnusedOperands(dead_root));
changed = true;
}
TF_ASSIGN_OR_RETURN(bool changed_for_computation,
RunOnComputation(computation));
changed |= changed_for_computation;
}
// Now DCE HloComputations. First, collect the computations that are

View File

@ -38,6 +38,9 @@ class HloDCE : public HloModulePass {
~HloDCE() override {}
absl::string_view name() const override { return "dce"; }
// Run DCE on a computation.
static StatusOr<bool> RunOnComputation(HloComputation* computation);
// Run the pass on the given module. Returns whether the module was changed
// (instructions were removed).
StatusOr<bool> Run(HloModule* module) override;