[XLA] Add a test/example to use gather to rearrange a dimension.
PiperOrigin-RevId: 304646840 Change-Id: Ic028f885ce5906cfddc2ae39db88e81fb8e3359c
This commit is contained in:
parent
da1caca48b
commit
7ddd300c98
@ -64,6 +64,30 @@ ENTRY main {
|
||||
RunTest(hlo_text, &operand, &start_indices);
|
||||
}
|
||||
|
||||
XLA_TEST_F(GatherOperationTest, BatchDimInMiddle) {
|
||||
// Reverse the middle dimension (dim 1).
|
||||
const string hlo_text = R"(
|
||||
HloModule BatchDimInMiddle
|
||||
|
||||
ENTRY main {
|
||||
operand = s32[3, 2, 3] parameter(0)
|
||||
indices = s32[2] parameter(1)
|
||||
ROOT gather = s32[3, 1, 2, 3] gather(operand, indices),
|
||||
offset_dims={0, 1, 3},
|
||||
collapsed_slice_dims={},
|
||||
start_index_map={1},
|
||||
index_vector_dim=1,
|
||||
slice_sizes={3, 1, 3}
|
||||
}
|
||||
)";
|
||||
Literal operand =
|
||||
LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}},
|
||||
{{7, 8, 9}, {10, 11, 12}},
|
||||
{{13, 14, 15}, {16, 17, 18}}});
|
||||
Literal start_indices = LiteralUtil::CreateR1<int32>({1, 0});
|
||||
RunTest(hlo_text, &operand, &start_indices);
|
||||
}
|
||||
|
||||
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
|
||||
const string hlo_text = R"(
|
||||
HloModule TensorFlowGatherV2
|
||||
|
Loading…
Reference in New Issue
Block a user