Update HloModuleDCE to call DCE and WhileLoopSimplifier after removing dead tuple elements in while loops.

PiperOrigin-RevId: 339248021
Change-Id: I76c9a8622d165e06e527b9fbbfba897792c04ce2
This commit is contained in:
A. Unique TensorFlower 2020-10-27 07:54:01 -07:00 committed by TensorFlower Gardener
parent d5df62973f
commit 8d74afab45
3 changed files with 119 additions and 13 deletions

View File

@ -3655,6 +3655,8 @@ cc_library(
":hlo_dce",
":hlo_liveness_analysis",
":hlo_pass",
":tuple_simplifier",
":while_loop_simplifier",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",

View File

@ -24,6 +24,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -38,6 +40,7 @@ namespace {
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
bool changed = false;
std::vector<HloComputation*> while_body_comps_to_dce;
for (auto* computation : module->computations()) {
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() != HloOpcode::kWhile) {
@ -60,6 +63,7 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
// Remove dead tuple elements.
const int64 tuple_element_count =
ShapeUtil::TupleElementCount(xla_while->shape());
bool modified_while_body_comp = false;
for (int64 i = 0; i < tuple_element_count; ++i) {
if (liveness->IsLive(xla_while, {i})) {
continue;
@ -79,9 +83,22 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
TF_RETURN_IF_ERROR(
while_body_root->ReplaceOperandWith(i, pass_thru_gte));
changed = true;
modified_while_body_comp = true;
}
if (modified_while_body_comp) {
while_body_comps_to_dce.push_back(while_body_comp);
}
}
}
// Run DCE on while body computations that we modified.
for (auto* while_body_comp : while_body_comps_to_dce) {
TF_ASSIGN_OR_RETURN(bool changed_for_computation,
HloDCE().RunOnComputation(
while_body_comp,
/*remove_cross_partition_collective_ops=*/false));
changed |= changed_for_computation;
}
return changed;
}
@ -100,6 +117,15 @@ StatusOr<bool> HloModuleDCE::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed,
RunWhileDCE(module, liveness.get()));
// Run the while loop simplifier to remove dead tuple elements.
WhileLoopSimplifier while_loop_simplifier;
TF_ASSIGN_OR_RETURN(bool while_loop_simplifier_changed,
while_loop_simplifier.Run(module));
TupleSimplifier tuple_simplifier;
TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
tuple_simplifier.Run(module));
// Run HloDCE to clean up any dead code created during HloModuleDCE.
HloDCE hlo_dce;
TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module));
@ -107,7 +133,8 @@ StatusOr<bool> HloModuleDCE::Run(HloModule* module) {
VLOG(2) << "After HloModuleDCE:";
XLA_VLOG_LINES(3, module->ToString());
return hlo_module_dce_changed | hlo_dce_changed;
return hlo_module_dce_changed | hlo_dce_changed | tuple_simplifier_changed |
while_loop_simplifier_changed;
}
} // namespace xla

View File

@ -66,6 +66,18 @@ class HloModuleDceTest : public HloTestBase {
}
return false;
}
// Returns all of the while loops in 'computation'.
std::vector<const HloInstruction*> GetWhileLoops(
const HloComputation* computation) {
std::vector<const HloInstruction*> while_loops;
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kWhile) {
while_loops.push_back(instruction);
}
}
return while_loops;
}
};
// Tests that a while with all outputs live is unmodified.
@ -182,8 +194,9 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while", 0));
// While tuple element {1} should now be pass-through after ModuleDCE.
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while", 1));
auto while_loops = GetWhileLoops(module->entry_computation());
EXPECT_EQ(1, while_loops.size());
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
}
// Tests that a tuple element {1} used by condition computation (which appears
@ -285,16 +298,16 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.2", 1));
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
// After HloModuleDCE while.1 and while.2 should have pass-thru elements,
// After HloModuleDCE while.1 and while.2 should have deleted tuple elements,
// after being modified to pass through unused tuple element {1}.
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.1", 0));
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.1", 1));
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.2", 0));
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.2", 1));
auto while_loops = GetWhileLoops(module->entry_computation());
EXPECT_EQ(2, while_loops.size());
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[1]->shape()));
}
// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
@ -356,12 +369,12 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
// After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements.
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.1", 1));
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.1", 0));
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.2", 0));
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.2", 1));
auto while_loops = GetWhileLoops(module->entry_computation());
EXPECT_EQ(2, while_loops.size());
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[1]->shape()));
}
// Tests that a while whose body has outfeed operations is not DCE-ed.
@ -431,10 +444,74 @@ TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
.ValueOrDie();
HloModuleDCE dce;
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
// Expect TRUE because while loop simplifier will remove dead tuple element.
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while", 0));
}
TEST_F(HloModuleDceTest, TwoWhilesWithDeadWhileLoop) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule TwoWhilesWithDeadWhileLoop
SimpleLoop.body0 {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
ROOT tuple = (s32[], s32[3]{0}) tuple(add, get-tuple-element.2)
}
SimpleLoop.condition0 {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(5)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
SimpleLoop.body1 {
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
constant.3 = s32[] constant(1)
add.1 = s32[] add(get-tuple-element.4, constant.3)
get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, get-tuple-element.5)
}
SimpleLoop.condition1 {
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
constant.4 = s32[] constant(5)
ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
}
ENTRY SimpleLoop {
constant.5 = s32[] constant(0)
constant.6 = s32[3]{0} constant({0, 1, 2})
tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6)
while.1 = (s32[], s32[3]{0}) while(tuple.2), condition=
SimpleLoop.condition0, body=SimpleLoop.body0
get-tuple-element.7 = s32[3]{0} get-tuple-element(while.1), index=1
constant.7 = s32[] constant(0)
tuple.3 = (s32[], s32[3]{0}) tuple(constant.7, get-tuple-element.7)
while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
SimpleLoop.condition1, body=SimpleLoop.body1
ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
})")
.ValueOrDie();
HloModuleDCE dce;
// Before HloModuleDCE while.1 and while.2 should have pass-thru elements.
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.1", 1));
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.2", 1));
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
// After HloModuleDCE while.1 and while.2 should have deleted tuple elements,
// after being modified to pass through unused tuple element {1}.
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while.2", 0));
auto while_loops = GetWhileLoops(module->entry_computation());
// Dead while.1 should be removed.
EXPECT_EQ(1, while_loops.size());
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
}
} // namespace
} // namespace xla