Merge pull request #46452 from wangsiyu:master

PiperOrigin-RevId: 353073762
Change-Id: I7b6a8d51eb0e2825d6f6c36667a5617803568749
This commit is contained in:
TensorFlower Gardener 2021-01-21 12:24:31 -08:00
commit 05ce78c1c5
2 changed files with 38 additions and 0 deletions

View File

@ -65,6 +65,12 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
}
}
if (updated_reduced_dimensions.empty()) {
std::unique_ptr<HloInstruction> reshape =
HloInstruction::CreateBitcast(reduce_shape, reduced_op);
return ReplaceWithNewInstruction(instr, std::move(reshape));
}
HloInstruction *input_reshape = instr->parent()->AddInstruction(
HloInstruction::CreateBitcast(canonical_input_shape, reduced_op));

View File

@ -69,6 +69,38 @@ ENTRY main {
)");
}
TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
const char* hlo_text = R"(
HloModule ReduceWithDegenerateDimensions
add {
accum = f32[] parameter(0)
op = f32[] parameter(1)
ROOT out = f32[] add(accum, op)
}
ENTRY main {
input = f32[1,3,1,4,1,5,1] parameter(0)
zero = f32[] constant(0)
ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
}
)";
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
// Copy instruction is added after bitcast because of copy-insertion pass,
// so we check the entire hlo module to verify there is no reduce instruction
// in this case.
MatchOptimizedHloWithShapes(hlo_text,
R"(
// CHECK: ENTRY %main (input: f32[1,3,1,4,1,5,1]) -> f32[3,4,5,1] {
// CHECK: %input = f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} parameter(0)
// CHECK: %bitcast{{.+}} = f32[3,4,5,1]{3,2,1,0} bitcast(f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} %input)
// CHECK: ROOT %copy{{.+}} = f32[3,4,5,1]{3,2,1,0} copy(f32[3,4,5,1]{3,2,1,0} %bitcast{{.+}})
)");
}
} // namespace
} // namespace gpu
} // namespace xla