[XLA] Do not simplify loops with trip count = 1 if there is an infeed in it.

PiperOrigin-RevId: 303217179
Change-Id: Ida39742d25319b878fbc10b675b2133bf2e6d5b4
This commit is contained in:
Yunxing Dai 2020-03-26 16:25:57 -07:00 committed by TensorFlower Gardener
parent 3009664be0
commit 55d96a7c83
3 changed files with 37 additions and 45 deletions

View File

@ -2228,6 +2228,7 @@ cc_library(
":while_loop_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
@ -21,8 +23,10 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
@ -1010,6 +1014,35 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
continue;
}
// Do not simplify the loop away when there is a side-effectful op,
// otherwise the infeed op may not inherit the data dependency from
// the while loop.
//
// Example: while_body (param_a) {
// param_a = parameter(0)
// infeed2 = infeed()
// }
//
// infeed1 = ...
// while = while(infeed1), body=while_body // infeed2 has implicit
// dependency on infeed1.
//
// After simplification:
//
// infeed1 = ...
// infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
// // can be scheduled after infeed2.
//
bool has_side_effects = absl::c_any_of(
while_op->called_computations(), [](const HloComputation* computation) {
return computation->HasSideEffect();
});
if (has_side_effects) {
VLOG(2) << "Not attempting to simplify while loop because it contains a "
"side-effecting node: "
<< while_op->ToShortString();
continue;
}
TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
changed |= result;

View File

@ -209,8 +209,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) {
EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
}
// We can simplify loops whose bodies contain infeed or other side-effecting
// instructions other than send/recv.
// We can't simplify loops whose bodies contain infeed or other side-effecting
// instructions.
TEST_F(WhileLoopSimplifierTest, LoopWithInfeedSimplified) {
auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1);
HloComputation* computation = m->entry_computation();
@ -220,8 +220,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedSimplified) {
auto token = while_body->AddInstruction(HloInstruction::CreateToken());
while_body->AddInstruction(HloInstruction::CreateInfeed(
ShapeUtil::MakeShape(F32, {1}), token, "config"));
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple());
EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
}
// We don't simplify trip-count-1 loops whose *conditions* contain infeed or
@ -445,47 +444,6 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
}
// Check that we can remove unused loop operands even if the loop contains a
// side-effecting instruction.
TEST_F(WhileLoopSimplifierTest,
RemoveUnusedLoopOperandsDespiteSideEffectingOps) {
const string hlo_string = R"(
HloModule RemoveUnusedOperands
body {
loop_var = (s32[]) parameter(0)
gte0 = s32[] get-tuple-element(loop_var), index=0
token0 = token[] after-all()
unused = ((s32[], pred[]), token[]) infeed(token0)
ROOT tuple = (s32[]) tuple(gte0)
}
cond {
loop_var = (s32[]) parameter(0)
ROOT constant = pred[] constant(true)
}
ENTRY RemoveUnusedOperands {
x = s32[] parameter(0)
tuple.1 = (s32[]) tuple(s32[] x)
ROOT while = (s32[]) while((s32[]) tuple.1),
condition=cond, body=body
}
)";
auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
// The original while instruction is still left in the module as a dead
// instruction, find a while instruction with a different name as the new
// while instruction.
const auto& instrs = m->entry_computation()->instructions();
HloInstruction* new_while_op =
*absl::c_find_if(instrs, [&](const HloInstruction* instr) {
return (instr->opcode() == HloOpcode::kWhile &&
instr->name() != "while");
});
EXPECT_TRUE(ShapeUtil::IsEmptyTuple(new_while_op->shape()))
<< new_while_op->shape().ToString();
}
TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
const string hlo_string = R"(
HloModule BodyHasNonTupleRoot