[XLA] Add while condition required assignment.
Memory space assignment should also ensure to add a required assignment for while condition parameters. PiperOrigin-RevId: 306351077 Change-Id: Icda7f6c41435001ac4098d7b94034266030ac9bf
This commit is contained in:
parent
498e5b4f6d
commit
cf50b1fb78
tensorflow/compiler/xla/service
@ -842,6 +842,9 @@ void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall(
|
||||
// Add aliased required assignments.
|
||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
HloComputation* while_body = use.instruction->while_body();
|
||||
HloComputation* while_condition = use.instruction->while_condition();
|
||||
AddAliasedRequiredAssignment(while_condition->parameter_instruction(0),
|
||||
use.operand_index, aliased_allocation);
|
||||
AddAliasedRequiredAssignment(while_body->parameter_instruction(0),
|
||||
use.operand_index, aliased_allocation);
|
||||
AddAliasedRequiredAssignment(while_body->root_instruction(),
|
||||
|
@ -1562,6 +1562,61 @@ TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) {
|
||||
AssignMemorySpace(module.get());
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
|
||||
// While loop is the root of the entry computation. We should ensure the
|
||||
// output of the entry computation remains to be in default memory space.
|
||||
// Test from //third_party/tensorflow/compiler/xla/tests:while_test
|
||||
// WhileTest.WhileWithPrngScalarResult.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule WhileWithPrngScalarResult.18, is_scheduled=true
|
||||
|
||||
%fused_computation (param_0.1: s32[6], param_1.3: s32[1], param_2.3: s32[5]) -> s32[6] {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)} parameter(0)
|
||||
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
}
|
||||
|
||||
%body.3 (prev.4: s32[6]) -> s32[6] {
|
||||
%constant.7 = s32[]{:T(128)} constant(100)
|
||||
%constant.6 = s32[]{:T(128)} constant(0)
|
||||
%constant.5 = s32[1]{0:T(128)} constant({1})
|
||||
%prev.4 = s32[6]{0:T(128)} parameter(0)
|
||||
%rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
|
||||
ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %constant.5, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
|
||||
}
|
||||
|
||||
%WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
|
||||
%constant.15 = s32[]{:T(128)} constant(1)
|
||||
%prev.12 = s32[6]{0:T(128)} parameter(0)
|
||||
%bitcast.1 = s32[1]{0:T(128)} bitcast(s32[6]{0:T(128)} %prev.12)
|
||||
%bitcast = s32[]{:T(128)} bitcast(s32[1]{0:T(128)} %bitcast.1)
|
||||
ROOT %compare.16 = pred[]{:T(128)E(32)} compare(s32[]{:T(128)} %constant.15, s32[]{:T(128)} %bitcast), direction=GT
|
||||
}
|
||||
|
||||
ENTRY %WhileWithPrngScalarResult.18 () -> s32[6] {
|
||||
%constant.1 = s32[]{:T(128)} constant(0)
|
||||
%broadcast.2 = s32[6]{0:T(128)} broadcast(s32[]{:T(128)} %constant.1), dimensions={}
|
||||
ROOT %while.17 = s32[6]{0:T(128)} while(s32[6]{0:T(128)} %broadcast.2), condition=%WhileWithPrngScalarResult.11, body=%body.3
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
AssignMemorySpace(module.get());
|
||||
// Expect the output to have default memory space.
|
||||
EXPECT_EQ(module->entry_computation()
|
||||
->root_instruction()
|
||||
->shape()
|
||||
.layout()
|
||||
.memory_space(),
|
||||
kDefaultMemorySpace);
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
|
||||
// Having control_predecessors on an HLO was preventing us from DCEing an op
|
||||
// that doesn't have any users (tuple.1). The scheduler assumes the graph is
|
||||
|
Loading…
Reference in New Issue
Block a user