[XLA] Turn gathers of effective scalars into broadcasts.
PiperOrigin-RevId: 327747629 Change-Id: I453a249e54e9d00407e022f2909906ed29ef8b85
This commit is contained in:
parent
cb0e3c6e6d
commit
b5aa5f3b2f
@ -2500,6 +2500,20 @@ Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
|
|||||||
if (ShapeUtil::IsZeroElementArray(operand_shape)) {
|
if (ShapeUtil::IsZeroElementArray(operand_shape)) {
|
||||||
return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
|
return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Gathering from a scalar operand is simply a broadcast of that scalar
|
||||||
|
if (ShapeUtil::IsEffectiveScalar(operand_shape)) {
|
||||||
|
HloInstruction* new_operand = gather->mutable_operand(0);
|
||||||
|
if (operand_shape.rank()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(new_operand,
|
||||||
|
MakeReshapeHlo(ShapeUtil::MakeScalarShape(
|
||||||
|
operand_shape.element_type()),
|
||||||
|
new_operand));
|
||||||
|
}
|
||||||
|
HloInstruction* new_gather =
|
||||||
|
MakeBroadcastHlo(new_operand, {}, gather->shape());
|
||||||
|
return ReplaceInstruction(gather, new_gather);
|
||||||
|
}
|
||||||
// If the operand of a gather is very small, it is easier to fuse a
|
// If the operand of a gather is very small, it is easier to fuse a
|
||||||
// sequence of selects.
|
// sequence of selects.
|
||||||
const Shape& index_shape = gather->operand(1)->shape();
|
const Shape& index_shape = gather->operand(1)->shape();
|
||||||
|
@ -5647,6 +5647,30 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
|
DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
|
||||||
::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
|
::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) {
|
||||||
|
const char* hlo_string = R"(
|
||||||
|
HloModule repeat
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
o = f32[1,1] parameter(0)
|
||||||
|
i = s32[100,2] parameter(1)
|
||||||
|
ROOT g = f32[100] gather(o, i), collapsed_slice_dims={0,1},
|
||||||
|
start_index_map={0,1},
|
||||||
|
index_vector_dim=1,
|
||||||
|
offset_dims={},
|
||||||
|
slice_sizes={1,1}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
|
AlgebraicSimplifierOptions options;
|
||||||
|
AlgebraicSimplifier simplifier(options);
|
||||||
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
|
auto root = module->entry_computation()->root_instruction();
|
||||||
|
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
|
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
|
||||||
const char* hlo_string = R"(
|
const char* hlo_string = R"(
|
||||||
HloModule module
|
HloModule module
|
||||||
|
@ -711,6 +711,24 @@ ENTRY main {
|
|||||||
RunTest(hlo_text, &operand, &start_indices);
|
RunTest(hlo_text, &operand, &start_indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(GatherOperationTest, GatherFromScalarNonZeroIndices) {
|
||||||
|
const string hlo_text = R"(
|
||||||
|
HloModule GatherFromScalar
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
operand = f32[1,1,1] parameter(0)
|
||||||
|
indices = s32[2,3,50] parameter(1)
|
||||||
|
ROOT gather = f32[1,2,50] gather(operand, indices),
|
||||||
|
offset_dims={0},
|
||||||
|
collapsed_slice_dims={0,1},
|
||||||
|
start_index_map={1,0,2},
|
||||||
|
index_vector_dim=1,
|
||||||
|
slice_sizes={1,1,1}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
class GatherClientLibraryTest : public ClientLibraryTestBase {};
|
class GatherClientLibraryTest : public ClientLibraryTestBase {};
|
||||||
|
|
||||||
// Disabled on interpreter since ExecuteAsyncOnStream is not supported.
|
// Disabled on interpreter since ExecuteAsyncOnStream is not supported.
|
||||||
|
Loading…
Reference in New Issue
Block a user