[XLA] Turn gathers of effective scalars into broadcasts.

PiperOrigin-RevId: 327747629
Change-Id: I453a249e54e9d00407e022f2909906ed29ef8b85
This commit is contained in:
Blake Hechtman 2020-08-20 20:33:40 -07:00 committed by TensorFlower Gardener
parent cb0e3c6e6d
commit b5aa5f3b2f
3 changed files with 56 additions and 0 deletions

View File

@ -2500,6 +2500,20 @@ Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
if (ShapeUtil::IsZeroElementArray(operand_shape)) { if (ShapeUtil::IsZeroElementArray(operand_shape)) {
return ReplaceInstruction(gather, MakeScalarLike(gather, 0)); return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
} }
// Gathering from a scalar operand is simply a broadcast of that scalar
if (ShapeUtil::IsEffectiveScalar(operand_shape)) {
HloInstruction* new_operand = gather->mutable_operand(0);
if (operand_shape.rank()) {
TF_ASSIGN_OR_RETURN(new_operand,
MakeReshapeHlo(ShapeUtil::MakeScalarShape(
operand_shape.element_type()),
new_operand));
}
HloInstruction* new_gather =
MakeBroadcastHlo(new_operand, {}, gather->shape());
return ReplaceInstruction(gather, new_gather);
}
// If the operand of a gather is very small, it is easier to fuse a // If the operand of a gather is very small, it is easier to fuse a
// sequence of selects. // sequence of selects.
const Shape& index_shape = gather->operand(1)->shape(); const Shape& index_shape = gather->operand(1)->shape();

View File

@ -5647,6 +5647,30 @@ INSTANTIATE_TEST_SUITE_P(
DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) {
const char* hlo_string = R"(
HloModule repeat
ENTRY main {
o = f32[1,1] parameter(0)
i = s32[100,2] parameter(1)
ROOT g = f32[100] gather(o, i), collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=1,
offset_dims={},
slice_sizes={1,1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AlgebraicSimplifierOptions options;
AlgebraicSimplifier simplifier(options);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) { TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
const char* hlo_string = R"( const char* hlo_string = R"(
HloModule module HloModule module

View File

@ -711,6 +711,24 @@ ENTRY main {
RunTest(hlo_text, &operand, &start_indices); RunTest(hlo_text, &operand, &start_indices);
} }
XLA_TEST_F(GatherOperationTest, GatherFromScalarNonZeroIndices) {
const string hlo_text = R"(
HloModule GatherFromScalar
ENTRY main {
operand = f32[1,1,1] parameter(0)
indices = s32[2,3,50] parameter(1)
ROOT gather = f32[1,2,50] gather(operand, indices),
offset_dims={0},
collapsed_slice_dims={0,1},
start_index_map={1,0,2},
index_vector_dim=1,
slice_sizes={1,1,1}
}
)";
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0, 0}));
}
class GatherClientLibraryTest : public ClientLibraryTestBase {}; class GatherClientLibraryTest : public ClientLibraryTestBase {};
// Disabled on interpreter since ExecuteAsyncOnStream is not supported. // Disabled on interpreter since ExecuteAsyncOnStream is not supported.