[XLA] Add a test/example to use gather to rearrange a dimension.

PiperOrigin-RevId: 304646840
Change-Id: Ic028f885ce5906cfddc2ae39db88e81fb8e3359c
This commit is contained in:
Yunxing Dai 2020-04-03 10:41:25 -07:00 committed by TensorFlower Gardener
parent da1caca48b
commit 7ddd300c98

View File

@ -64,6 +64,30 @@ ENTRY main {
RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, BatchDimInMiddle) {
// Reverse the middle dimension (dim 1).
const string hlo_text = R"(
HloModule BatchDimInMiddle
ENTRY main {
operand = s32[3, 2, 3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3, 1, 2, 3] gather(operand, indices),
offset_dims={0, 1, 3},
collapsed_slice_dims={},
start_index_map={1},
index_vector_dim=1,
slice_sizes={3, 1, 3}
}
)";
Literal operand =
LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}},
{{7, 8, 9}, {10, 11, 12}},
{{13, 14, 15}, {16, 17, 18}}});
Literal start_indices = LiteralUtil::CreateR1<int32>({1, 0});
RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
const string hlo_text = R"(
HloModule TensorFlowGatherV2