Fix a bug reshape dynamic dimension inference bug.
This bug happens when we have a reshape adding a degenerated dimension between two dynamic dimensions. E.g.,: Reshape [<=9, <=2] -> [<=9, 1, <=2] The fix is for each input dimension, only look at the valid subrange of the dynamic dimensions, instead of looking at the full shape. PiperOrigin-RevId: 343134740 Change-Id: I2b66dcb45063a176bd85fcc75900428f3a767d6d
This commit is contained in:
parent
8fde5290d6
commit
3228d99ead
@ -911,7 +911,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
output_dynamic_dimension = reshape->inferred_dimension();
|
||||
if (output_dynamic_dimension == -1) {
|
||||
// Try find dynamic dimension from the result shape.
|
||||
for (int64 i = 0; i < reshape->shape().rank(); ++i) {
|
||||
for (int64 i = output_dim_start; i < output_dim_end; ++i) {
|
||||
if (reshape->shape().is_dynamic_dimension(i)) {
|
||||
output_dynamic_dimension = i;
|
||||
}
|
||||
|
@ -1277,5 +1277,31 @@ TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) {
|
||||
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), dynamic_size);
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionInferenceTest, ReshapeOpWithMultipleDynamicDimensions) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto input = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {9, 2}), "data_input"));
|
||||
auto six = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(6)));
|
||||
input = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
|
||||
ShapeUtil::MakeShape(F32, {9, 2}, {true, false}), input, six, 0));
|
||||
auto one = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
|
||||
input = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
|
||||
ShapeUtil::MakeShape(F32, {9, 2}, {true, true}), input, one, 1));
|
||||
|
||||
// Reshape [<=9, <=2] into [<=9, 1, <=2]
|
||||
|
||||
auto dynamic_reshape = builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {9, 1, 2}, {true, false, true}), input));
|
||||
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
|
||||
TF_ASSERT_OK(RunInference());
|
||||
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 0), six);
|
||||
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), nullptr);
|
||||
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 2), one);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user