Update HloModuleDCE to call DCE and WhileLoopSimplifier after removing dead tuple elements in while loops.
PiperOrigin-RevId: 339248021 Change-Id: I76c9a8622d165e06e527b9fbbfba897792c04ce2
This commit is contained in:
parent
d5df62973f
commit
8d74afab45
@ -3655,6 +3655,8 @@ cc_library(
|
||||
":hlo_dce",
|
||||
":hlo_liveness_analysis",
|
||||
":hlo_pass",
|
||||
":tuple_simplifier",
|
||||
":while_loop_simplifier",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -24,6 +24,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -38,6 +40,7 @@ namespace {
|
||||
|
||||
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
|
||||
bool changed = false;
|
||||
std::vector<HloComputation*> while_body_comps_to_dce;
|
||||
for (auto* computation : module->computations()) {
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
if (instruction->opcode() != HloOpcode::kWhile) {
|
||||
@ -60,6 +63,7 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
|
||||
// Remove dead tuple elements.
|
||||
const int64 tuple_element_count =
|
||||
ShapeUtil::TupleElementCount(xla_while->shape());
|
||||
bool modified_while_body_comp = false;
|
||||
for (int64 i = 0; i < tuple_element_count; ++i) {
|
||||
if (liveness->IsLive(xla_while, {i})) {
|
||||
continue;
|
||||
@ -79,9 +83,22 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
while_body_root->ReplaceOperandWith(i, pass_thru_gte));
|
||||
changed = true;
|
||||
modified_while_body_comp = true;
|
||||
}
|
||||
if (modified_while_body_comp) {
|
||||
while_body_comps_to_dce.push_back(while_body_comp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run DCE on while body computations that we modified.
|
||||
for (auto* while_body_comp : while_body_comps_to_dce) {
|
||||
TF_ASSIGN_OR_RETURN(bool changed_for_computation,
|
||||
HloDCE().RunOnComputation(
|
||||
while_body_comp,
|
||||
/*remove_cross_partition_collective_ops=*/false));
|
||||
changed |= changed_for_computation;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
@ -100,6 +117,15 @@ StatusOr<bool> HloModuleDCE::Run(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed,
|
||||
RunWhileDCE(module, liveness.get()));
|
||||
|
||||
// Run the while loop simplifier to remove dead tuple elements.
|
||||
WhileLoopSimplifier while_loop_simplifier;
|
||||
TF_ASSIGN_OR_RETURN(bool while_loop_simplifier_changed,
|
||||
while_loop_simplifier.Run(module));
|
||||
|
||||
TupleSimplifier tuple_simplifier;
|
||||
TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
|
||||
tuple_simplifier.Run(module));
|
||||
|
||||
// Run HloDCE to clean up any dead code created during HloModuleDCE.
|
||||
HloDCE hlo_dce;
|
||||
TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module));
|
||||
@ -107,7 +133,8 @@ StatusOr<bool> HloModuleDCE::Run(HloModule* module) {
|
||||
VLOG(2) << "After HloModuleDCE:";
|
||||
XLA_VLOG_LINES(3, module->ToString());
|
||||
|
||||
return hlo_module_dce_changed | hlo_dce_changed;
|
||||
return hlo_module_dce_changed | hlo_dce_changed | tuple_simplifier_changed |
|
||||
while_loop_simplifier_changed;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -66,6 +66,18 @@ class HloModuleDceTest : public HloTestBase {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns all of the while loops in 'computation'.
|
||||
std::vector<const HloInstruction*> GetWhileLoops(
|
||||
const HloComputation* computation) {
|
||||
std::vector<const HloInstruction*> while_loops;
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kWhile) {
|
||||
while_loops.push_back(instruction);
|
||||
}
|
||||
}
|
||||
return while_loops;
|
||||
}
|
||||
};
|
||||
|
||||
// Tests that a while with all outputs live is unmodified.
|
||||
@ -182,8 +194,9 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
|
||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while", 0));
|
||||
// While tuple element {1} should now be pass-through after ModuleDCE.
|
||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while", 1));
|
||||
auto while_loops = GetWhileLoops(module->entry_computation());
|
||||
EXPECT_EQ(1, while_loops.size());
|
||||
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
|
||||
}
|
||||
|
||||
// Tests that a tuple element {1} used by condition computation (which appears
|
||||
@ -285,16 +298,16 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
|
||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while.2", 1));
|
||||
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
||||
// After HloModuleDCE while.1 and while.2 should have pass-thru elements,
|
||||
// After HloModuleDCE while.1 and while.2 should have deleted tuple elements,
|
||||
// after being modified to pass through unused tuple element {1}.
|
||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while.1", 0));
|
||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while.1", 1));
|
||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while.2", 0));
|
||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while.2", 1));
|
||||
auto while_loops = GetWhileLoops(module->entry_computation());
|
||||
EXPECT_EQ(2, while_loops.size());
|
||||
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
|
||||
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[1]->shape()));
|
||||
}
|
||||
|
||||
// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
|
||||
@ -356,12 +369,12 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
|
||||
// After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements.
|
||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while.1", 1));
|
||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while.1", 0));
|
||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while.2", 0));
|
||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while.2", 1));
|
||||
auto while_loops = GetWhileLoops(module->entry_computation());
|
||||
EXPECT_EQ(2, while_loops.size());
|
||||
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
|
||||
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[1]->shape()));
|
||||
}
|
||||
|
||||
// Tests that a while whose body has outfeed operations is not DCE-ed.
|
||||
@ -431,10 +444,74 @@ TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
|
||||
.ValueOrDie();
|
||||
|
||||
HloModuleDCE dce;
|
||||
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
|
||||
// Expect TRUE because while loop simplifier will remove dead tuple element.
|
||||
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while", 0));
|
||||
}
|
||||
|
||||
// Tests that a while loop whose only use feeds a dead tuple element of a
// second while loop is removed entirely. Here while.1 is consumed only through
// its tuple element {1}, which becomes while.2's element {1}; while.2's body
// merely passes that element through and the entry root reads only while.2{0}.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadWhileLoop) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TwoWhilesWithDeadWhileLoop
  SimpleLoop.body0 {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, get-tuple-element.2)
  }
  SimpleLoop.condition0 {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  SimpleLoop.body1 {
    loop_var.3 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
    constant.3 = s32[] constant(1)
    add.1 = s32[] add(get-tuple-element.4, constant.3)
    get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
    ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, get-tuple-element.5)
  }
  SimpleLoop.condition1 {
    loop_var.4 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
    constant.4 = s32[] constant(5)
    ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
  }
  ENTRY SimpleLoop {
    constant.5 = s32[] constant(0)
    constant.6 = s32[3]{0} constant({0, 1, 2})
    tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6)
    while.1 = (s32[], s32[3]{0}) while(tuple.2), condition=
      SimpleLoop.condition0, body=SimpleLoop.body0
    get-tuple-element.7 = s32[3]{0} get-tuple-element(while.1), index=1
    constant.7 = s32[] constant(0)
    tuple.3 = (s32[], s32[3]{0}) tuple(constant.7, get-tuple-element.7)
    while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
      SimpleLoop.condition1, body=SimpleLoop.body1
    ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
  })")
                    .ValueOrDie();

  HloModuleDCE dce;
  // Before HloModuleDCE while.1 and while.2 should have pass-thru elements.
  EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                  "while.1", 1));
  EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                  "while.2", 1));
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
  // After HloModuleDCE the surviving loop's element {0} (the loop counter that
  // its body increments) must not be a pass-thru element.
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while.2", 0));
  auto while_loops = GetWhileLoops(module->entry_computation());
  // Dead while.1 should be removed. Use a fatal assertion so that the indexing
  // below never reads past the end of the vector when the count is wrong.
  ASSERT_EQ(1, while_loops.size());
  // The remaining loop should carry only the live tuple element.
  EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user