[XLA] Add while condition required assignment.

Memory space assignment should also add a required assignment for the while
condition parameter, which aliases the same buffer as the while body parameter
and root.
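
For context, a while loop's operand, body parameter, body root, and condition
parameter all alias one buffer, so a memory space required for any of them has
to be mirrored on the others. The sketch below only enumerates those positions
for a non-tuple-shaped while; the helper name is illustrative and not part of
this change, while the HloInstruction/HloComputation accessors are the standard
XLA ones.

  // Illustrative sketch only: the HLO positions that alias a kWhile buffer.
  #include <vector>

  #include "tensorflow/compiler/xla/service/hlo_computation.h"
  #include "tensorflow/compiler/xla/service/hlo_instruction.h"

  std::vector<const xla::HloInstruction*> AliasedWhilePositions(
      const xla::HloInstruction* while_instr) {
    std::vector<const xla::HloInstruction*> positions;
    // The value flowing into the loop.
    positions.push_back(while_instr->operand(0));
    // The same value as seen inside the loop body (parameter and root).
    positions.push_back(while_instr->while_body()->parameter_instruction(0));
    positions.push_back(while_instr->while_body()->root_instruction());
    // The condition also reads the loop state through its parameter; this is
    // the position this change adds a required assignment for.
    positions.push_back(
        while_instr->while_condition()->parameter_instruction(0));
    return positions;
  }

For tuple-shaped while loops the same reasoning applies per ShapeIndex element;
the new test below uses a flat s32[6] buffer, so a single position per
instruction suffices.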

PiperOrigin-RevId: 306351077
Change-Id: Icda7f6c41435001ac4098d7b94034266030ac9bf
Author: Berkin Ilbeyi, 2020-04-13 18:18:42 -07:00; committed by TensorFlower Gardener
parent 498e5b4f6d
commit cf50b1fb78
2 changed files with 58 additions and 0 deletions


@@ -842,6 +842,9 @@ void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall(
  // Add aliased required assignments.
  if (use.instruction->opcode() == HloOpcode::kWhile) {
    HloComputation* while_body = use.instruction->while_body();
    HloComputation* while_condition = use.instruction->while_condition();
    AddAliasedRequiredAssignment(while_condition->parameter_instruction(0),
                                 use.operand_index, aliased_allocation);
    AddAliasedRequiredAssignment(while_body->parameter_instruction(0),
                                 use.operand_index, aliased_allocation);
    AddAliasedRequiredAssignment(while_body->root_instruction(),
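
Not part of the diff: one hedged way to see the aliasing that the added
required assignment mirrors is that HloAliasAnalysis reports a single buffer
shared by the while instruction and the condition parameter. The accessors
below (Run, GetUniqueBufferAt, HloBuffer::id) are assumed from the XLA API of
this era and may differ in other versions; the check itself is a sketch, not
code from this change.

  // Hedged sketch, not in this change: confirm the while condition parameter
  // shares a buffer with the while instruction, which is why a required
  // memory space on one must also be recorded on the other.
  #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
  #include "tensorflow/compiler/xla/service/hlo_instruction.h"
  #include "tensorflow/core/platform/logging.h"

  void CheckWhileCondAliasing(xla::HloModule* module,
                              const xla::HloInstruction* while_instr) {
    auto alias_analysis = xla::HloAliasAnalysis::Run(module).ValueOrDie();
    const xla::HloBuffer& while_buffer =
        alias_analysis->GetUniqueBufferAt(while_instr, /*index=*/{});
    const xla::HloBuffer& cond_param_buffer =
        alias_analysis->GetUniqueBufferAt(
            while_instr->while_condition()->parameter_instruction(0),
            /*index=*/{});
    CHECK_EQ(while_buffer.id(), cond_param_buffer.id());
  }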


@@ -1562,6 +1562,61 @@ TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) {
  AssignMemorySpace(module.get());
}

TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
  // While loop is the root of the entry computation. We should ensure the
  // output of the entry computation remains to be in default memory space.
  // Test from //third_party/tensorflow/compiler/xla/tests:while_test
  // WhileTest.WhileWithPrngScalarResult.
  absl::string_view hlo_string = R"(
HloModule WhileWithPrngScalarResult.18, is_scheduled=true
%fused_computation (param_0.1: s32[6], param_1.3: s32[1], param_2.3: s32[5]) -> s32[6] {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)} parameter(0)
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
}
%body.3 (prev.4: s32[6]) -> s32[6] {
%constant.7 = s32[]{:T(128)} constant(100)
%constant.6 = s32[]{:T(128)} constant(0)
%constant.5 = s32[1]{0:T(128)} constant({1})
%prev.4 = s32[6]{0:T(128)} parameter(0)
%rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %constant.5, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
}
%WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
%constant.15 = s32[]{:T(128)} constant(1)
%prev.12 = s32[6]{0:T(128)} parameter(0)
%bitcast.1 = s32[1]{0:T(128)} bitcast(s32[6]{0:T(128)} %prev.12)
%bitcast = s32[]{:T(128)} bitcast(s32[1]{0:T(128)} %bitcast.1)
ROOT %compare.16 = pred[]{:T(128)E(32)} compare(s32[]{:T(128)} %constant.15, s32[]{:T(128)} %bitcast), direction=GT
}
ENTRY %WhileWithPrngScalarResult.18 () -> s32[6] {
%constant.1 = s32[]{:T(128)} constant(0)
%broadcast.2 = s32[6]{0:T(128)} broadcast(s32[]{:T(128)} %constant.1), dimensions={}
ROOT %while.17 = s32[6]{0:T(128)} while(s32[6]{0:T(128)} %broadcast.2), condition=%WhileWithPrngScalarResult.11, body=%body.3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
  // Expect the output to have default memory space.
  EXPECT_EQ(module->entry_computation()
                ->root_instruction()
                ->shape()
                .layout()
                .memory_space(),
            kDefaultMemorySpace);
}

TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
  // Having control_predecessors on an HLO was preventing us from DCEing an op
  // that doesn't have any users (tuple.1). The scheduler assumes the graph is