Fix a bug reshape dynamic dimension inference bug.

This bug happens when we have a reshape adding a degenerated dimension between two dynamic dimensions.

E.g.,: Reshape [<=9, <=2] -> [<=9, 1, <=2]

The fix is for each input dimension, only look at the valid subrange of the dynamic dimensions, instead of looking at the full shape.

PiperOrigin-RevId: 343134740
Change-Id: I2b66dcb45063a176bd85fcc75900428f3a767d6d
This commit is contained in:
Yunxing Dai 2020-11-18 12:39:28 -08:00 committed by TensorFlower Gardener
parent 8fde5290d6
commit 3228d99ead
2 changed files with 27 additions and 1 deletions

View File

@ -911,7 +911,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
output_dynamic_dimension = reshape->inferred_dimension();
if (output_dynamic_dimension == -1) {
// Try find dynamic dimension from the result shape.
for (int64 i = 0; i < reshape->shape().rank(); ++i) {
for (int64 i = output_dim_start; i < output_dim_end; ++i) {
if (reshape->shape().is_dynamic_dimension(i)) {
output_dynamic_dimension = i;
}

View File

@ -1277,5 +1277,31 @@ TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) {
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), dynamic_size);
}
TEST_F(DynamicDimensionInferenceTest, ReshapeOpWithMultipleDynamicDimensions) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {9, 2}), "data_input"));
auto six = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(6)));
input = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
ShapeUtil::MakeShape(F32, {9, 2}, {true, false}), input, six, 0));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
input = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
ShapeUtil::MakeShape(F32, {9, 2}, {true, true}), input, one, 1));
// Reshape [<=9, <=2] into [<=9, 1, <=2]
auto dynamic_reshape = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {9, 1, 2}, {true, false, true}), input));
module_->AddEntryComputation(builder.Build());
TF_ASSERT_OK(RunInference());
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 0), six);
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), nullptr);
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 2), one);
}
} // namespace
} // namespace xla