STT-tensorflow/tensorflow/compiler/xla/tests/gather_operation_test.cc
Blake Hechtman b5aa5f3b2f [XLA] Turn gathers of effective scalars into broadcasts.
PiperOrigin-RevId: 327747629
Change-Id: I453a249e54e9d00407e022f2909906ed29ef8b85
2020-08-20 20:37:28 -07:00

791 lines
25 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
using absl::nullopt;
class GatherOperationTest : public HloTestBase {
protected:
void RunTest(const string& hlo_text, Literal* operand,
Literal* start_indices) {
RunTest(hlo_text, {operand, start_indices});
}
void RunTest(const string& hlo_text, absl::Span<Literal* const> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text, config));
EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
}
};
// Gather on a 3x3 s32 operand driven by a rank-1 index array: dimension 0 is
// indexed and collapsed, and slice_sizes={1, 3} spans all of dimension 1.
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV1) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal gather_indices = LiteralUtil::CreateR1<int32>({0, 2});
  const string module_text = R"(
HloModule TensorFlowGatherV1
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1, 3}
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Indexes dimension 1 of a 3x2x3 operand with the indices {1, 0}, i.e. visits
// the middle dimension in reverse order. offset_dims={0, 1, 3} leaves output
// dimension 2 for the batch dimension, so it sits between offset dimensions.
XLA_TEST_F(GatherOperationTest, BatchDimInMiddle) {
  Literal input = LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}},
                                                {{7, 8, 9}, {10, 11, 12}},
                                                {{13, 14, 15}, {16, 17, 18}}});
  Literal gather_indices = LiteralUtil::CreateR1<int32>({1, 0});
  const string module_text = R"(
HloModule BatchDimInMiddle
ENTRY main {
operand = s32[3, 2, 3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3, 1, 2, 3] gather(operand, indices),
offset_dims={0, 1, 3},
collapsed_slice_dims={},
start_index_map={1},
index_vector_dim=1,
slice_sizes={3, 1, 3}
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Gather on a 3x3 s32 operand where dimension 1 is indexed and collapsed and
// slice_sizes={3, 1} spans all of dimension 0.
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal gather_indices = LiteralUtil::CreateR1<int32>({0, 2});
  const string module_text = R"(
HloModule TensorFlowGatherV2
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
offset_dims={0},
collapsed_slice_dims={1},
start_index_map={1},
index_vector_dim=1,
slice_sizes={3, 1}
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Same dimension-1 gather as above but with a rank-2 (2x2) index array and
// index_vector_dim=2, producing an s32[2,3,2] result.
XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal gather_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  const string module_text = R"(
HloModule TensorFlowGatherMultipleBatchDims
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={1},
start_index_map={1},
index_vector_dim=2,
slice_sizes={3, 1}
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Element gather from a 3x3 operand using a 2x2x2 index array whose last
// dimension holds the index vector; both operand dimensions are collapsed, so
// there are no offset dimensions and the result is s32[2,2].
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal gather_indices =
      LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
  const string module_text = R"(
HloModule TensorFlowGatherNdMultipleBatchDims
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
offset_dims={},
collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=2,
slice_sizes={1, 1}
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Variant of the previous element gather that collapses nothing: the 1x1
// slices are kept as offset_dims={1,2}, giving an s32[2,1,1,2] result.
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal gather_indices =
      LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
  const string module_text = R"(
HloModule TensorFlowGatherNdMultipleBatchDims
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,1,1,2] gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=2,
slice_sizes={1, 1}
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Gathers length-2 slices (slice_sizes={1,1,2}) from a 3x3x2 operand; the two
// leading operand dimensions are indexed and collapsed, with the index vector
// stored along dimension 1 of the 2x2 index array.
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) {
  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal gather_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  const string module_text = R"(
HloModule TensorFlowGatherNd
ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=1,
slice_sizes={1,1,2}
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Same gather as TensorFlowGatherNd except that the index vector is read along
// dimension 0 of the index array (index_vector_dim=0).
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) {
  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal gather_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  const string module_text = R"(
HloModule TensorFlowGatherNd
ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=0,
slice_sizes={1,1,2}
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// A gather with a single two-component index vector and no collapsed
// dimensions: it extracts one 1x1 slice of the 3x3 operand.
XLA_TEST_F(GatherOperationTest, DynamicSlice) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal slice_start = LiteralUtil::CreateR1<int32>({1, 1});
  const char* module_text = R"(
HloModule DynamicSlice
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
offset_dims={0,1},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=0,
slice_sizes={1,1}
}
)";
  RunTest(module_text, &input, &slice_start);
}
// Batched form of DynamicSlice: a 2x2 index array with index_vector_dim=0
// selects two 1x1 slices, producing an s32[2,1,1] result.
XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal slice_starts = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  const string module_text = R"(
HloModule BatchDynamicSlice
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=0,
slice_sizes={1,1}
}
)";
  RunTest(module_text, &input, &slice_starts);
}
// Gather from an operand with a zero-sized dimension (s32[3,0]); the slice
// size along that dimension is 0 and the result shape is s32[2,0].
XLA_TEST_F(GatherOperationTest, ZeroDimBounds) {
  Literal empty_rows = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal gather_indices = LiteralUtil::CreateR1<int32>({0, 2});
  const char* module_text = R"(
HloModule TensorFlowGatherV1
ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1, 0}
}
)";
  RunTest(module_text, &empty_rows, &gather_indices);
}
// Some of the s32 index vectors below point outside the 3x3 operand (7, 5 and
// 2147483647). Such indices must not crash, and the in-range indices should
// produce the same values across all backends.
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal gather_indices = LiteralUtil::CreateR2<int32>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
  const string module_text = R"(
HloModule BatchDynamicSlice
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=1,
slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// The next 2 tests use data types that require extra steps on some backends, so
// only run them on known good backends.
#if defined(XLA_TEST_BACKEND_GPU) || defined(XLA_TEST_BACKEND_CPU) || \
defined(XLA_TEST_BACKEND_INTERPRETER)
// Uses s64 indices, a type wider than is needed to address the 3x3 operand,
// including the value 21474836407. Out-of-bounds indices must not crash, and
// results should agree across backends.
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex64Bit) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal wide_indices = LiteralUtil::CreateR2<int64>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {21474836407, 1}, {1, 2}});
  const string module_text = R"(
HloModule BatchDynamicSlice
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s64[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=1,
slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
  RunTest(module_text, &input, &wide_indices);
}
// Uses u8 indices against a 512x512 operand: the index type cannot address
// every location of the operand, and the gather should still not fail.
XLA_TEST_F(GatherOperationTest, TooSmallIndex8Bit) {
  Literal identity = LiteralUtil::MakeIdentityR2<int32>(512);
  Literal narrow_indices = LiteralUtil::CreateR2<uint8>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {7, 1}, {1, 2}});
  const string module_text = R"(
HloModule BatchDynamicSlice
ENTRY main {
operand = s32[512, 512]{1,0} parameter(0)
indices = u8[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=1,
slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
  RunTest(module_text, &identity, &narrow_indices);
}
#endif
// Unsigned (u32) counterpart of OutOfBoundsIndex, including the index value
// 2147483648u. Out-of-bounds indices must not crash, and the in-range indices
// should produce the same values across all backends.
XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal unsigned_indices = LiteralUtil::CreateR2<uint32>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
  const string module_text = R"(
HloModule BatchDynamicSlice
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = u32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=1,
slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
  RunTest(module_text, &input, &unsigned_indices);
}
// Some index components below are negative (-1, -500 and the most negative
// int32). Negative indices must not crash, and the in-range indices should
// produce the same values across all backends.
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
  const int32 lowest_index = static_cast<int32>(-2147483648);
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal signed_indices = LiteralUtil::CreateR2<int32>({{2, -1},
                                                         {2, 1},
                                                         {1, 1},
                                                         {-500, 1},
                                                         {lowest_index, 1},
                                                         {1, 2}});
  const string module_text = R"(
HloModule BatchDynamicSlice
ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=1,
slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
  RunTest(module_text, &input, &signed_indices);
}
// Same negative s32 indices as NegativeIndex, but gathering from a u32
// operand. Negative indices must not crash, and the in-range indices should
// produce the same values across all backends.
XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
  const int32 lowest_index = static_cast<int32>(-2147483648);
  Literal unsigned_input =
      LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal signed_indices = LiteralUtil::CreateR2<int32>({{2, -1},
                                                         {2, 1},
                                                         {1, 1},
                                                         {-500, 1},
                                                         {lowest_index, 1},
                                                         {1, 2}});
  const string module_text = R"(
HloModule BatchDynamicSlice
ENTRY main {
operand = u32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = u32[6,1,1]{2,1,0} gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=1,
slice_sizes={1,1}
ROOT result = u32[6]{0} reshape(gather)
}
)";
  RunTest(module_text, &unsigned_input, &signed_indices);
}
// The index operand is a scalar (s32[]) that selects along dimension 0 of a
// 2x3x2 operand; nothing is collapsed, so the result is s32[1,3,2].
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
  Literal input = LiteralUtil::CreateR3<int32>(
      {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
  Literal scalar_index = LiteralUtil::CreateR0<int32>(1);
  const char* module_text = R"(
HloModule OneScalarIndex
ENTRY main {
operand = s32[2,3,2]{2,1,0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index),
offset_dims={0,1,2},
collapsed_slice_dims={},
start_index_map={0},
index_vector_dim=0,
slice_sizes={1,3,2}
}
)";
  RunTest(module_text, &input, &scalar_index);
}
// A scalar index into a rank-1 operand with the only dimension collapsed, so
// the gather result itself is a scalar (s32[]).
XLA_TEST_F(GatherOperationTest, ScalarResult) {
  Literal input = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
  Literal scalar_index = LiteralUtil::CreateR0<int32>(1);
  const char* module_text = R"(
HloModule ScalarResult
ENTRY main {
operand = s32[4]{0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[] gather(operand, index),
offset_dims={},
collapsed_slice_dims={0},
start_index_map={0},
index_vector_dim=0,
slice_sizes={1}
}
)";
  RunTest(module_text, &input, &scalar_index);
}
// The index array is empty (s32[0]), so the gather produces a zero-sized
// s32[0,3] result.
XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal no_indices = LiteralUtil::CreateR1<int32>({});
  const string module_text = R"(
HloModule ZeroSizedResult
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[0] parameter(1)
ROOT gather = s32[0,3] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1, 3}
}
)";
  RunTest(module_text, &input, &no_indices);
}
// TensorFlowGatherV2-style gather whose result is added to a broadcast
// constant 1, so the gather is no longer the root instruction.
// NOTE(review): presumably the extra add lets backends fuse the gather --
// confirm.
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal gather_indices = LiteralUtil::CreateR1<int32>({0, 2});
  const string module_text = R"(
HloModule FusedTensorFlowGatherV2
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[3,2] gather(operand, indices),
offset_dims={0},
collapsed_slice_dims={1},
start_index_map={1},
index_vector_dim=1,
slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[3,2] broadcast(one), dimensions={}
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Gather with a 2x2 index array (index_vector_dim=2) whose s32[2,3,2] result
// is added to a broadcast constant 1 before being returned.
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal gather_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  const string module_text = R"(
HloModule FusedTensorFlowGatherMultipleBatchDims
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,3,2] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={1},
start_index_map={1},
index_vector_dim=2,
slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Element gather driven by a 2x2x2 index array, with both operand dimensions
// collapsed; the s32[2,2] result is added to a broadcast constant 1.
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal gather_indices =
      LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
  const string module_text = R"(
HloModule FusedTensorFlowGatherNdMultipleBatchDims
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
offset_dims={},
collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=2,
slice_sizes={1, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Gathers length-2 slices from a 3x3x2 operand (index_vector_dim=1) and adds
// a broadcast constant 1 to the s32[2,2] gather result.
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal gather_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  const string module_text = R"(
HloModule FusedTensorFlowGatherNd
ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=1,
slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Same as FusedTensorFlowGatherNd, but the index vector is read along
// dimension 0 of the index array (index_vector_dim=0).
XLA_TEST_F(GatherOperationTest,
           FusedTensorFlowGatherNdNonDefaultIndexVectorDim) {
  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal gather_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  const string module_text = R"(
HloModule FusedTensorFlowGatherNd
ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=0,
slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
  RunTest(module_text, &input, &gather_indices);
}
// Single 1x1 slice gather (one two-component index vector, nothing collapsed)
// whose result is added to a broadcast constant 1.
XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal slice_start = LiteralUtil::CreateR1<int32>({1, 1});
  const char* module_text = R"(
HloModule FusedDynamicSlice
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[1,1] gather(operand, indices),
offset_dims={0,1},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=0,
slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[1,1] broadcast(one), dimensions={}
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
}
)";
  RunTest(module_text, &input, &slice_start);
}
// Two 1x1 slices selected by a 2x2 index array (index_vector_dim=0); the
// s32[2,1,1] gather result is added to a broadcast constant 1.
XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal slice_starts = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  const string module_text = R"(
HloModule FusedBatchDynamicSlice
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,1,1] gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=0,
slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
}
)";
  RunTest(module_text, &input, &slice_starts);
}
// Degenerate gather: the operand is a scalar (f32[]), the index array is empty
// (s32[0]) and every dimension-number list is empty; the result is f32[].
XLA_TEST_F(GatherOperationTest, GatherFromScalar) {
  Literal scalar_input = LiteralUtil::CreateR0<float>(1);
  Literal no_indices = LiteralUtil::CreateR1<int32>({});
  const string module_text = R"(
HloModule GatherFromScalar
ENTRY main {
operand = f32[] parameter(0)
indices = s32[0]{0} parameter(1)
ROOT gather = f32[] gather(operand, indices),
offset_dims={},
collapsed_slice_dims={},
start_index_map={},
index_vector_dim=0,
slice_sizes={}
}
)";
  RunTest(module_text, &scalar_input, &no_indices);
}
// Gather from a single-element operand (f32[1,1,1]) with a 2x3x50 index array.
// No argument literals are supplied here: the HLO text goes straight to
// RunAndCompare together with ErrorSpec{0, 0}.
// NOTE(review): presumably RunAndCompare generates the inputs itself and the
// zero ErrorSpec demands an exact match -- confirm in HloTestBase.
XLA_TEST_F(GatherOperationTest, GatherFromScalarNonZeroIndices) {
  const string module_text = R"(
HloModule GatherFromScalar
ENTRY main {
operand = f32[1,1,1] parameter(0)
indices = s32[2,3,50] parameter(1)
ROOT gather = f32[1,2,50] gather(operand, indices),
offset_dims={0},
collapsed_slice_dims={0,1},
start_index_map={1,0,2},
index_vector_dim=1,
slice_sizes={1,1,1}
}
)";
  const auto comparison = RunAndCompare(module_text, ErrorSpec{0, 0});
  EXPECT_TRUE(comparison);
}
// Fixture for gather tests that go through the client library; it adds nothing
// on top of ClientLibraryTestBase.
class GatherClientLibraryTest : public ClientLibraryTestBase {
};
// Builds a gather computation with the XlaBuilder API, transfers both
// arguments to the server, runs the computation through
// Client::ExecuteParallel on one device handle, and checks the transferred
// result against the expected 2x3 literal.
//
// Disabled on interpreter since ExecuteAsyncOnStream is not supported.
XLA_TEST_F(GatherClientLibraryTest,
           DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(Basic))) {
  // We create this HLO, but using the XlaBuilder API.
  //
  // ENTRY main {
  //   operand = s32[3,3] parameter(0)
  //   indices = s32[2] parameter(1)
  //   ROOT gather = s32[2,3] gather(operand, indices),
  //       offset_dims={1},
  //       collapsed_slice_dims={0},
  //       start_index_map={0},
  //       index_vector_dim=1,
  //       slice_sizes={1, 3}
  // }
  XlaBuilder builder("gather_basic");

  Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3});
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});

  auto operand = Parameter(&builder, 0, operand_shape, "operand");
  auto indices = Parameter(&builder, 1, indices_shape, "indices");

  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_collapsed_slice_dims(0);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  Gather(operand, indices, dim_numbers, {1, 3});

  // Upload the operand and the indices so they can be passed by handle.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> operand_arg,
      client_->TransferToServer(
          LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> indices_arg,
      client_->TransferToServer(LiteralUtil::CreateR1<int32>({0, 2})));

  // Request a single device handle and target it in the execution options.
  TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
                          client_->GetDeviceHandles(1));
  xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
  *execution_options.add_device_handles() = devices[0];

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build());
  std::vector<xla::Client::XlaComputationInstance> computation_instances = {
      {computation,
       {operand_arg.get(), indices_arg.get()},
       execution_options,
       /*execution_profile=*/nullptr}};
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<std::unique_ptr<xla::GlobalData>> result_data,
      client_->ExecuteParallel(computation_instances));

  // Only one computation instance was submitted, so inspect its sole result.
  TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                          client_->Transfer(*(result_data[0])));
  LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}}, result_literal);
}
} // namespace
} // namespace xla