[XLA] Add a test for dynamic shapes vs ComputationLayout

PiperOrigin-RevId: 357787747
Change-Id: I7a0a0403e3c1cfd31e754be1da0568a6ef257973
This commit is contained in:
David Majnemer 2021-02-16 13:05:07 -08:00 committed by TensorFlower Gardener
parent 11baf03f4e
commit 66ebcc356c
2 changed files with 25 additions and 0 deletions

View File

@ -3961,6 +3961,7 @@ tf_cc_test(
deps = [
":algebraic_simplifier",
":computation_layout",
":dynamic_padder",
":hlo",
":hlo_parser",
":layout_assignment",

View File

@ -1422,5 +1422,29 @@ ENTRY entry_computation {
ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
}
TEST_F(LayoutAssignmentTest, DynamicRoot) {
const char* module_str = R"(
HloModule test_module
ENTRY entry_computation {
param = f32[1,<=16]{0,1} parameter(0)
ROOT abs = f32[1,<=16]{0,1} abs(param)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
ParseAndReturnVerifiedModule(module_str));
ComputationLayout computation_layout(
m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
computation_layout.mutable_result_layout()->ClearDynamicShape();
AssignLayouts(m.get(), &computation_layout);
const HloInstruction* abs = FindInstruction(m.get(), "abs");
ExpectLayoutIs(abs->operand(0)->shape(), {0, 1});
ExpectLayoutIs(abs->shape(), {0, 1});
EXPECT_TRUE(abs->shape().is_dynamic_dimension(1));
}
} // namespace
} // namespace xla