diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 9a19427a96a..0fd5f191db0 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -64,6 +64,30 @@ ENTRY main { RunTest(hlo_text, &operand, &start_indices); } +XLA_TEST_F(GatherOperationTest, BatchDimInMiddle) { + // Reverse the middle dimension (dim 1). + const string hlo_text = R"( +HloModule BatchDimInMiddle + +ENTRY main { + operand = s32[3, 2, 3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3, 1, 2, 3] gather(operand, indices), + offset_dims={0, 1, 3}, + collapsed_slice_dims={}, + start_index_map={1}, + index_vector_dim=1, + slice_sizes={3, 1, 3} +} +)"; + Literal operand = + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}, + {{7, 8, 9}, {10, 11, 12}}, + {{13, 14, 15}, {16, 17, 18}}}); + Literal start_indices = LiteralUtil::CreateR1({1, 0}); + RunTest(hlo_text, &operand, &start_indices); +} + XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { const string hlo_text = R"( HloModule TensorFlowGatherV2