diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ba65a29b983..44128b90333 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3961,6 +3961,7 @@ tf_cc_test( deps = [ ":algebraic_simplifier", ":computation_layout", + ":dynamic_padder", ":hlo", ":hlo_parser", ":layout_assignment", diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 304a80c7a52..987ed9009a8 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -1422,5 +1422,29 @@ ENTRY entry_computation { ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0}); } +TEST_F(LayoutAssignmentTest, DynamicRoot) { + const char* module_str = R"( +HloModule test_module + +ENTRY entry_computation { + param = f32[1,<=16]{0,1} parameter(0) + ROOT abs = f32[1,<=16]{0,1} abs(param) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m, + ParseAndReturnVerifiedModule(module_str)); + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); + computation_layout.mutable_result_layout()->ClearDynamicShape(); + + AssignLayouts(m.get(), &computation_layout); + + const HloInstruction* abs = FindInstruction(m.get(), "abs"); + ExpectLayoutIs(abs->operand(0)->shape(), {0, 1}); + ExpectLayoutIs(abs->shape(), {0, 1}); + EXPECT_TRUE(abs->shape().is_dynamic_dimension(1)); +} + } // namespace } // namespace xla