diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 8b220e1833b..53efbcadd44 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -911,7 +911,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { output_dynamic_dimension = reshape->inferred_dimension(); if (output_dynamic_dimension == -1) { // Try find dynamic dimension from the result shape. - for (int64 i = 0; i < reshape->shape().rank(); ++i) { + for (int64 i = output_dim_start; i < output_dim_end; ++i) { if (reshape->shape().is_dynamic_dimension(i)) { output_dynamic_dimension = i; } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index 69f64c31a2f..77c36b375d0 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -1277,5 +1277,31 @@ TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) { EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), dynamic_size); } +TEST_F(DynamicDimensionInferenceTest, ReshapeOpWithMultipleDynamicDimensions) { + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {9, 2}), "data_input")); + auto six = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(6))); + input = builder.AddInstruction(HloInstruction::CreateSetDimensionSize( + ShapeUtil::MakeShape(F32, {9, 2}, {true, false}), input, six, 0)); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + input = builder.AddInstruction(HloInstruction::CreateSetDimensionSize( + ShapeUtil::MakeShape(F32, {9, 2}, {true, true}), input, one, 1)); + + // Reshape [<=9, <=2] into [<=9, 1, <=2] + + auto dynamic_reshape = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {9, 1, 2}, {true, false, true}), input)); + + module_->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 0), six); + EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 2), one); +} + } // namespace } // namespace xla