Merge pull request #31989 from georgepaw:hlo_dce
PiperOrigin-RevId: 266199694
This commit is contained in:
commit
f0ebc5b745
@ -35,33 +35,48 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
|
||||||
|
bool changed = false;
|
||||||
|
VLOG(3) << "Before dce:";
|
||||||
|
XLA_VLOG_LINES(3, computation->ToString());
|
||||||
|
// Remove any dead roots and their dead transitive operands. Collect them
|
||||||
|
// into a separate list first to avoid problems with iterating through the
|
||||||
|
// computation's instruction while simultaneously removing instructions.
|
||||||
|
std::vector<HloInstruction*> dead_roots;
|
||||||
|
for (auto* instruction : computation->instructions()) {
|
||||||
|
if (instruction != computation->root_instruction() &&
|
||||||
|
instruction->user_count() == 0 &&
|
||||||
|
computation->IsSafelyRemovable(instruction) &&
|
||||||
|
!instruction->HasSideEffect()) {
|
||||||
|
dead_roots.push_back(instruction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (HloInstruction* dead_root : dead_roots) {
|
||||||
|
VLOG(1) << "Removing dead root " << dead_root->ToString()
|
||||||
|
<< " and it's unused operands";
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
computation->RemoveInstructionAndUnusedOperands(dead_root));
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
if (changed) {
|
||||||
|
VLOG(3) << "After dce:";
|
||||||
|
XLA_VLOG_LINES(3, computation->ToString());
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<bool> HloDCE::Run(HloModule* module) {
|
StatusOr<bool> HloDCE::Run(HloModule* module) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
|
||||||
VLOG(2) << "Before dce:";
|
VLOG(2) << "Before dce:";
|
||||||
XLA_VLOG_LINES(2, module->ToString());
|
XLA_VLOG_LINES(2, module->ToString());
|
||||||
|
|
||||||
|
// Run DCE on each computation.
|
||||||
for (auto* computation : module->MakeComputationPostOrder()) {
|
for (auto* computation : module->MakeComputationPostOrder()) {
|
||||||
// Remove any dead roots and their dead transitive operands. Collect them
|
TF_ASSIGN_OR_RETURN(bool changed_for_computation,
|
||||||
// into a separate list first to avoid problems with iterating through the
|
RunOnComputation(computation));
|
||||||
// computation's instruction while simultaneously removing instructions.
|
changed |= changed_for_computation;
|
||||||
std::vector<HloInstruction*> dead_roots;
|
|
||||||
for (auto* instruction : computation->instructions()) {
|
|
||||||
if (instruction != computation->root_instruction() &&
|
|
||||||
instruction->user_count() == 0 &&
|
|
||||||
computation->IsSafelyRemovable(instruction) &&
|
|
||||||
!instruction->HasSideEffect()) {
|
|
||||||
dead_roots.push_back(instruction);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (HloInstruction* dead_root : dead_roots) {
|
|
||||||
VLOG(1) << "Removing dead root " << dead_root->ToString()
|
|
||||||
<< " and it's unused operands";
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
computation->RemoveInstructionAndUnusedOperands(dead_root));
|
|
||||||
changed = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now DCE HloComputations. First, collect the computations that are
|
// Now DCE HloComputations. First, collect the computations that are
|
||||||
|
@ -38,6 +38,9 @@ class HloDCE : public HloModulePass {
|
|||||||
~HloDCE() override {}
|
~HloDCE() override {}
|
||||||
absl::string_view name() const override { return "dce"; }
|
absl::string_view name() const override { return "dce"; }
|
||||||
|
|
||||||
|
// Run DCE on a computation.
|
||||||
|
static StatusOr<bool> RunOnComputation(HloComputation* computation);
|
||||||
|
|
||||||
// Run the pass on the given module. Returns whether the module was changed
|
// Run the pass on the given module. Returns whether the module was changed
|
||||||
// (instructions were removed).
|
// (instructions were removed).
|
||||||
StatusOr<bool> Run(HloModule* module) override;
|
StatusOr<bool> Run(HloModule* module) override;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user