diff --git a/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.cc b/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.cc index 2c786b577fc..e373010408b 100644 --- a/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.cc +++ b/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.cc @@ -65,6 +65,12 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor { } } + if (updated_reduced_dimensions.empty()) { + std::unique_ptr reshape = + HloInstruction::CreateBitcast(reduce_shape, reduced_op); + return ReplaceWithNewInstruction(instr, std::move(reshape)); + } + HloInstruction *input_reshape = instr->parent()->AddInstruction( HloInstruction::CreateBitcast(canonical_input_shape, reduced_op)); diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc index 92f558ee98d..f5031817818 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc @@ -69,6 +69,38 @@ ENTRY main { )"); } +TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) { + const char* hlo_text = R"( +HloModule ReduceWithDegenerateDimensions + +add { + accum = f32[] parameter(0) + op = f32[] parameter(1) + ROOT out = f32[] add(accum, op) +} + +ENTRY main { + input = f32[1,3,1,4,1,5,1] parameter(0) + zero = f32[] constant(0) + + ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + // Copy instruction is added after bitcast because of copy-insertion pass, + // so we check the entire hlo module to verify there is no reduce instruction + // in this case. + MatchOptimizedHloWithShapes(hlo_text, + R"( +// CHECK: ENTRY %main (input: f32[1,3,1,4,1,5,1]) -> f32[3,4,5,1] { +// CHECK: %input = f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} parameter(0) +// CHECK: %bitcast{{.+}} = f32[3,4,5,1]{3,2,1,0} bitcast(f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} %input) +// CHECK: ROOT %copy{{.+}} = f32[3,4,5,1]{3,2,1,0} copy(f32[3,4,5,1]{3,2,1,0} %bitcast{{.+}}) + )"); +} + } // namespace } // namespace gpu } // namespace xla