Update HloModuleDCE to call DCE and WhileLoopSimplifier after removing dead tuple elements in while loops.
PiperOrigin-RevId: 339248021
Change-Id: I76c9a8622d165e06e527b9fbbfba897792c04ce2
This commit is contained in:
parent
d5df62973f
commit
8d74afab45
@ -3655,6 +3655,8 @@ cc_library(
|
|||||||
":hlo_dce",
|
":hlo_dce",
|
||||||
":hlo_liveness_analysis",
|
":hlo_liveness_analysis",
|
||||||
":hlo_pass",
|
":hlo_pass",
|
||||||
|
":tuple_simplifier",
|
||||||
|
":while_loop_simplifier",
|
||||||
"//tensorflow/compiler/xla:status",
|
"//tensorflow/compiler/xla:status",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
@ -24,6 +24,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
|
#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
|
||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
@ -38,6 +40,7 @@ namespace {
|
|||||||
|
|
||||||
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
|
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
std::vector<HloComputation*> while_body_comps_to_dce;
|
||||||
for (auto* computation : module->computations()) {
|
for (auto* computation : module->computations()) {
|
||||||
for (auto* instruction : computation->instructions()) {
|
for (auto* instruction : computation->instructions()) {
|
||||||
if (instruction->opcode() != HloOpcode::kWhile) {
|
if (instruction->opcode() != HloOpcode::kWhile) {
|
||||||
@ -60,6 +63,7 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
|
|||||||
// Remove dead tuple elements.
|
// Remove dead tuple elements.
|
||||||
const int64 tuple_element_count =
|
const int64 tuple_element_count =
|
||||||
ShapeUtil::TupleElementCount(xla_while->shape());
|
ShapeUtil::TupleElementCount(xla_while->shape());
|
||||||
|
bool modified_while_body_comp = false;
|
||||||
for (int64 i = 0; i < tuple_element_count; ++i) {
|
for (int64 i = 0; i < tuple_element_count; ++i) {
|
||||||
if (liveness->IsLive(xla_while, {i})) {
|
if (liveness->IsLive(xla_while, {i})) {
|
||||||
continue;
|
continue;
|
||||||
@ -79,9 +83,22 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
while_body_root->ReplaceOperandWith(i, pass_thru_gte));
|
while_body_root->ReplaceOperandWith(i, pass_thru_gte));
|
||||||
changed = true;
|
changed = true;
|
||||||
|
modified_while_body_comp = true;
|
||||||
|
}
|
||||||
|
if (modified_while_body_comp) {
|
||||||
|
while_body_comps_to_dce.push_back(while_body_comp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run DCE on while body computations that we modified.
|
||||||
|
for (auto* while_body_comp : while_body_comps_to_dce) {
|
||||||
|
TF_ASSIGN_OR_RETURN(bool changed_for_computation,
|
||||||
|
HloDCE().RunOnComputation(
|
||||||
|
while_body_comp,
|
||||||
|
/*remove_cross_partition_collective_ops=*/false));
|
||||||
|
changed |= changed_for_computation;
|
||||||
|
}
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,6 +117,15 @@ StatusOr<bool> HloModuleDCE::Run(HloModule* module) {
|
|||||||
TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed,
|
TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed,
|
||||||
RunWhileDCE(module, liveness.get()));
|
RunWhileDCE(module, liveness.get()));
|
||||||
|
|
||||||
|
// Run the while loop simplifier to remove dead tuple elements.
|
||||||
|
WhileLoopSimplifier while_loop_simplifier;
|
||||||
|
TF_ASSIGN_OR_RETURN(bool while_loop_simplifier_changed,
|
||||||
|
while_loop_simplifier.Run(module));
|
||||||
|
|
||||||
|
TupleSimplifier tuple_simplifier;
|
||||||
|
TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
|
||||||
|
tuple_simplifier.Run(module));
|
||||||
|
|
||||||
// Run HloDCE to clean up any dead code created during HloModuleDCE.
|
// Run HloDCE to clean up any dead code created during HloModuleDCE.
|
||||||
HloDCE hlo_dce;
|
HloDCE hlo_dce;
|
||||||
TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module));
|
TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module));
|
||||||
@ -107,7 +133,8 @@ StatusOr<bool> HloModuleDCE::Run(HloModule* module) {
|
|||||||
VLOG(2) << "After HloModuleDCE:";
|
VLOG(2) << "After HloModuleDCE:";
|
||||||
XLA_VLOG_LINES(3, module->ToString());
|
XLA_VLOG_LINES(3, module->ToString());
|
||||||
|
|
||||||
return hlo_module_dce_changed | hlo_dce_changed;
|
return hlo_module_dce_changed | hlo_dce_changed | tuple_simplifier_changed |
|
||||||
|
while_loop_simplifier_changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -66,6 +66,18 @@ class HloModuleDceTest : public HloTestBase {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns all of the while loops in 'computation'.
|
||||||
|
std::vector<const HloInstruction*> GetWhileLoops(
|
||||||
|
const HloComputation* computation) {
|
||||||
|
std::vector<const HloInstruction*> while_loops;
|
||||||
|
for (auto* instruction : computation->instructions()) {
|
||||||
|
if (instruction->opcode() == HloOpcode::kWhile) {
|
||||||
|
while_loops.push_back(instruction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return while_loops;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Tests that a while with all outputs live is unmodified.
|
// Tests that a while with all outputs live is unmodified.
|
||||||
@ -182,8 +194,9 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
|
|||||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
"while", 0));
|
"while", 0));
|
||||||
// While tuple element {1} should now be pass-through after ModuleDCE.
|
// While tuple element {1} should now be removed entirely after ModuleDCE.
|
||||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
auto while_loops = GetWhileLoops(module->entry_computation());
|
||||||
"while", 1));
|
EXPECT_EQ(1, while_loops.size());
|
||||||
|
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that a tuple element {1} used by condition computation (which appears
|
// Tests that a tuple element {1} used by condition computation (which appears
|
||||||
@ -285,16 +298,16 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
|
|||||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
"while.2", 1));
|
"while.2", 1));
|
||||||
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
||||||
// After HloModuleDCE while.1 and while.2 should have pass-thru elements,
|
// After HloModuleDCE while.1 and while.2 should have deleted tuple elements,
|
||||||
// after being modified to pass through unused tuple element {1}.
|
// after being modified to pass through unused tuple element {1}.
|
||||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
"while.1", 0));
|
"while.1", 0));
|
||||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
|
||||||
"while.1", 1));
|
|
||||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
"while.2", 0));
|
"while.2", 0));
|
||||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
auto while_loops = GetWhileLoops(module->entry_computation());
|
||||||
"while.2", 1));
|
EXPECT_EQ(2, while_loops.size());
|
||||||
|
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
|
||||||
|
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[1]->shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
|
// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
|
||||||
@ -356,12 +369,12 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
|
|||||||
// After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements.
|
// After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements.
|
||||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
"while.1", 1));
|
"while.1", 1));
|
||||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
|
||||||
"while.1", 0));
|
|
||||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
"while.2", 0));
|
"while.2", 0));
|
||||||
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
auto while_loops = GetWhileLoops(module->entry_computation());
|
||||||
"while.2", 1));
|
EXPECT_EQ(2, while_loops.size());
|
||||||
|
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
|
||||||
|
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[1]->shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that a while whose body has outfeed operations is not DCE-ed.
|
// Tests that a while whose body has outfeed operations is not DCE-ed.
|
||||||
@ -431,10 +444,74 @@ TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
|
|||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
|
|
||||||
HloModuleDCE dce;
|
HloModuleDCE dce;
|
||||||
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
|
// Expect TRUE because while loop simplifier will remove dead tuple element.
|
||||||
|
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
||||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
"while", 0));
|
"while", 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloModuleDceTest, TwoWhilesWithDeadWhileLoop) {
|
||||||
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
|
HloModule TwoWhilesWithDeadWhileLoop
|
||||||
|
SimpleLoop.body0 {
|
||||||
|
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
|
||||||
|
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
|
||||||
|
constant.1 = s32[] constant(1)
|
||||||
|
add = s32[] add(get-tuple-element.1, constant.1)
|
||||||
|
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
|
||||||
|
ROOT tuple = (s32[], s32[3]{0}) tuple(add, get-tuple-element.2)
|
||||||
|
}
|
||||||
|
SimpleLoop.condition0 {
|
||||||
|
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||||
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
|
constant.2 = s32[] constant(5)
|
||||||
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
|
}
|
||||||
|
SimpleLoop.body1 {
|
||||||
|
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
|
||||||
|
get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
|
||||||
|
constant.3 = s32[] constant(1)
|
||||||
|
add.1 = s32[] add(get-tuple-element.4, constant.3)
|
||||||
|
get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
|
||||||
|
ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, get-tuple-element.5)
|
||||||
|
}
|
||||||
|
SimpleLoop.condition1 {
|
||||||
|
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
|
||||||
|
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
|
||||||
|
constant.4 = s32[] constant(5)
|
||||||
|
ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
|
||||||
|
}
|
||||||
|
ENTRY SimpleLoop {
|
||||||
|
constant.5 = s32[] constant(0)
|
||||||
|
constant.6 = s32[3]{0} constant({0, 1, 2})
|
||||||
|
tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6)
|
||||||
|
while.1 = (s32[], s32[3]{0}) while(tuple.2), condition=
|
||||||
|
SimpleLoop.condition0, body=SimpleLoop.body0
|
||||||
|
get-tuple-element.7 = s32[3]{0} get-tuple-element(while.1), index=1
|
||||||
|
constant.7 = s32[] constant(0)
|
||||||
|
tuple.3 = (s32[], s32[3]{0}) tuple(constant.7, get-tuple-element.7)
|
||||||
|
while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
|
||||||
|
SimpleLoop.condition1, body=SimpleLoop.body1
|
||||||
|
ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
|
||||||
|
})")
|
||||||
|
.ValueOrDie();
|
||||||
|
|
||||||
|
HloModuleDCE dce;
|
||||||
|
// Before HloModuleDCE while.1 and while.2 should have pass-thru elements.
|
||||||
|
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
|
"while.1", 1));
|
||||||
|
EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
|
"while.2", 1));
|
||||||
|
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
||||||
|
// After HloModuleDCE while.1 and while.2 should have deleted tuple elements,
|
||||||
|
// after being modified to pass through unused tuple element {1}.
|
||||||
|
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||||
|
"while.2", 0));
|
||||||
|
auto while_loops = GetWhileLoops(module->entry_computation());
|
||||||
|
// Dead while.1 should be removed.
|
||||||
|
EXPECT_EQ(1, while_loops.size());
|
||||||
|
EXPECT_EQ(1, ShapeUtil::TupleElementCount(while_loops[0]->shape()));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user