[XLA] Add a test/example to use gather to rearrange a dimension.
PiperOrigin-RevId: 304646840 Change-Id: Ic028f885ce5906cfddc2ae39db88e81fb8e3359c
This commit is contained in:
parent
da1caca48b
commit
7ddd300c98
@ -64,6 +64,30 @@ ENTRY main {
|
|||||||
RunTest(hlo_text, &operand, &start_indices);
|
RunTest(hlo_text, &operand, &start_indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(GatherOperationTest, BatchDimInMiddle) {
|
||||||
|
// Reverse the middle dimension (dim 1).
|
||||||
|
const string hlo_text = R"(
|
||||||
|
HloModule BatchDimInMiddle
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
operand = s32[3, 2, 3] parameter(0)
|
||||||
|
indices = s32[2] parameter(1)
|
||||||
|
ROOT gather = s32[3, 1, 2, 3] gather(operand, indices),
|
||||||
|
offset_dims={0, 1, 3},
|
||||||
|
collapsed_slice_dims={},
|
||||||
|
start_index_map={1},
|
||||||
|
index_vector_dim=1,
|
||||||
|
slice_sizes={3, 1, 3}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
Literal operand =
|
||||||
|
LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}},
|
||||||
|
{{7, 8, 9}, {10, 11, 12}},
|
||||||
|
{{13, 14, 15}, {16, 17, 18}}});
|
||||||
|
Literal start_indices = LiteralUtil::CreateR1<int32>({1, 0});
|
||||||
|
RunTest(hlo_text, &operand, &start_indices);
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
|
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
|
||||||
const string hlo_text = R"(
|
const string hlo_text = R"(
|
||||||
HloModule TensorFlowGatherV2
|
HloModule TensorFlowGatherV2
|
||||||
|
Loading…
Reference in New Issue
Block a user