[XLA] Do not simplify loops with trip count = 1 if there is an infeed in it.
PiperOrigin-RevId: 303217179 Change-Id: Ida39742d25319b878fbc10b675b2133bf2e6d5b4
This commit is contained in:
parent
3009664be0
commit
55d96a7c83
@ -2228,6 +2228,7 @@ cc_library(
|
||||
":while_loop_analysis",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
@ -21,8 +23,10 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_query.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
|
||||
@ -1010,6 +1014,35 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do not simplify the loop away when there is a side-effectful op,
|
||||
// otherwise the infeed op may not inherit the data dependency from
|
||||
// the while loop.
|
||||
//
|
||||
// Example: while_body (param_a) {
|
||||
// param_a = parameter(0)
|
||||
// infeed2 = infeed()
|
||||
// }
|
||||
//
|
||||
// infeed1 = ...
|
||||
// while = while(infeed1), body=while_body // infeed2 has implicit
|
||||
// dependency on infeed1.
|
||||
//
|
||||
// After simplification:
|
||||
//
|
||||
// infeed1 = ...
|
||||
// infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
|
||||
// // can be scheduled after infeed2.
|
||||
//
|
||||
bool has_side_effects = absl::c_any_of(
|
||||
while_op->called_computations(), [](const HloComputation* computation) {
|
||||
return computation->HasSideEffect();
|
||||
});
|
||||
if (has_side_effects) {
|
||||
VLOG(2) << "Not attempting to simplify while loop because it contains a "
|
||||
"side-effecting node: "
|
||||
<< while_op->ToShortString();
|
||||
continue;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
|
||||
changed |= result;
|
||||
|
||||
|
@ -209,8 +209,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) {
|
||||
EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
// We can simplify loops whose bodies contain infeed or other side-effecting
|
||||
// instructions other than send/recv.
|
||||
// We can't simplify loops whose bodies contain infeed or other side-effecting
|
||||
// instructions.
|
||||
TEST_F(WhileLoopSimplifierTest, LoopWithInfeedSimplified) {
|
||||
auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1);
|
||||
HloComputation* computation = m->entry_computation();
|
||||
@ -220,8 +220,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedSimplified) {
|
||||
auto token = while_body->AddInstruction(HloInstruction::CreateToken());
|
||||
while_body->AddInstruction(HloInstruction::CreateInfeed(
|
||||
ShapeUtil::MakeShape(F32, {1}), token, "config"));
|
||||
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple());
|
||||
EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
// We don't simplify trip-count-1 loops whose *conditions* contain infeed or
|
||||
@ -445,47 +444,6 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
|
||||
op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
|
||||
}
|
||||
|
||||
// Check that we can remove unused loop operands even if the loop contains a
|
||||
// side-effecting instruction.
|
||||
TEST_F(WhileLoopSimplifierTest,
|
||||
RemoveUnusedLoopOperandsDespiteSideEffectingOps) {
|
||||
const string hlo_string = R"(
|
||||
HloModule RemoveUnusedOperands
|
||||
body {
|
||||
loop_var = (s32[]) parameter(0)
|
||||
gte0 = s32[] get-tuple-element(loop_var), index=0
|
||||
token0 = token[] after-all()
|
||||
unused = ((s32[], pred[]), token[]) infeed(token0)
|
||||
ROOT tuple = (s32[]) tuple(gte0)
|
||||
}
|
||||
cond {
|
||||
loop_var = (s32[]) parameter(0)
|
||||
ROOT constant = pred[] constant(true)
|
||||
}
|
||||
ENTRY RemoveUnusedOperands {
|
||||
x = s32[] parameter(0)
|
||||
tuple.1 = (s32[]) tuple(s32[] x)
|
||||
ROOT while = (s32[]) while((s32[]) tuple.1),
|
||||
condition=cond, body=body
|
||||
}
|
||||
)";
|
||||
|
||||
auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
|
||||
|
||||
// The original while instruction is still left in the module as a dead
|
||||
// instruction, find a while instruction with a different name as the new
|
||||
// while instruction.
|
||||
const auto& instrs = m->entry_computation()->instructions();
|
||||
HloInstruction* new_while_op =
|
||||
*absl::c_find_if(instrs, [&](const HloInstruction* instr) {
|
||||
return (instr->opcode() == HloOpcode::kWhile &&
|
||||
instr->name() != "while");
|
||||
});
|
||||
EXPECT_TRUE(ShapeUtil::IsEmptyTuple(new_while_op->shape()))
|
||||
<< new_while_op->shape().ToString();
|
||||
}
|
||||
|
||||
TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
|
||||
const string hlo_string = R"(
|
||||
HloModule BodyHasNonTupleRoot
|
||||
|
Loading…
x
Reference in New Issue
Block a user