[XLA:CLIENT] Add batch dimensions to TorchIndexSelect

PiperOrigin-RevId: 248150697
Blake Hechtman authored 2019-05-14 09:26:11 -07:00; committed by TensorFlower Gardener
parent 78830556b9
commit ae6df76700
4 changed files with 55 additions and 8 deletions

tensorflow/compiler/xla/client/lib/BUILD

@@ -296,6 +296,7 @@ cc_library(
     hdrs = ["slicing.h"],
     deps = [
         "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/client:xla_builder",
         "@com_google_absl//absl/types:span",
     ],

tensorflow/compiler/xla/client/lib/slicing.cc

@@ -14,7 +14,9 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/util.h"
 
 namespace xla {
@@ -161,22 +163,43 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) {
   });
 }
 
-XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim) {
+XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) {
   XlaBuilder* builder = input.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
+    if (dim < batch_dims) {
+      return InvalidArgument(
+          "Gather dim must be greater than or equal to the number of batch "
+          "dims");
+    }
     std::vector<int64> slice_sizes = input_shape.dimensions();
-    slice_sizes[dim] = 1;
     GatherDimensionNumbers gather_dnums;
-    for (int64 i = 0; i < input_shape.rank(); ++i) {
-      if (i != dim) {
-        gather_dnums.add_offset_dims(i);
-      }
-    }
     gather_dnums.set_index_vector_dim(index_shape.rank());
-    gather_dnums.add_collapsed_slice_dims(dim);
-    gather_dnums.add_start_index_map(dim);
+    if (batch_dims > 0) {
+      ShapeUtil::AppendMajorDimension(1, &index_shape);
+      std::vector<XlaOp> to_concat;
+      to_concat.reserve(batch_dims + 1);
+      for (int64 batch_dim = 0; batch_dim < batch_dims; ++batch_dim) {
+        to_concat.push_back(Iota(builder, index_shape, batch_dim));
+      }
+      to_concat.push_back(Reshape(index, index_shape.dimensions()));
+      index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim());
+    }
+    for (int64 i = 0; i < input_shape.rank(); ++i) {
+      if (i < batch_dims || i == dim) {
+        slice_sizes[i] = 1;
+        gather_dnums.add_collapsed_slice_dims(i);
+        gather_dnums.add_start_index_map(i);
+      } else {
+        if (i < dim) {
+          gather_dnums.add_offset_dims(i);
+        } else {
+          gather_dnums.add_offset_dims(i + gather_dnums.index_vector_dim() -
+                                       (1 + batch_dims));
+        }
+      }
+    }
     return Gather(input, index, gather_dnums, slice_sizes);
   });
 }
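
For intuition (not part of the commit): with batch_dims = 1, the leading dimension of input and index is paired up, and within each batch b the op picks input[b][index[b][j]] along dim. The new code reaches the same result in a single Gather by concatenating Iota-generated batch coordinates onto the reshaped index, so each start index becomes (b, index[b][j]), and by listing the batch dimensions as collapsed slice dims with slice size 1. Below is a plain-C++ sketch of those reference semantics, run on the same data as the BatchTorchIndexSelectOn0 test added further down; the Matrix/Cube aliases are local to the sketch.

#include <cassert>
#include <vector>

using Matrix = std::vector<std::vector<int>>;
using Cube = std::vector<Matrix>;

int main() {
  // Same operands as the new test: input is [2, 3, 4], index is [2, 2],
  // dim = 1, batch_dims = 1.
  Cube input = {{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
                {{3, 2, 1, 0}, {7, 6, 5, 4}, {11, 10, 9, 8}}};
  Matrix index = {{0, 2}, {1, 2}};

  // The batch dimension acts like an outer loop over both operands:
  // output[b][j] = input[b][index[b][j]].
  Cube output;
  for (size_t b = 0; b < input.size(); ++b) {
    Matrix rows;
    for (int j : index[b]) rows.push_back(input[b][j]);
    output.push_back(rows);
  }

  // Matches the expectation in BatchTorchIndexSelectOn0.
  Cube expected = {{{0, 1, 2, 3}, {8, 9, 10, 11}},
                   {{7, 6, 5, 4}, {11, 10, 9, 8}}};
  assert(output == expected);
}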

tensorflow/compiler/xla/client/lib/slicing.h

@@ -63,7 +63,11 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim);
 // The returned tensor has the same number of dimensions as the original tensor
 // (input). The dimth dimension has the same size as the length of index; other
 // dimensions have the same size as in the original tensor.
-XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim);
+//
+// This operation supports 0 or more major batch dimensions that act like a
+// multidimensional loop over both the input and the index.
+XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim,
+                       int64 batch_dims = 0);
 
 }  // namespace xla
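
A minimal call-site sketch of the new overload (not from the commit; the helper name BuildBatchedIndexSelect and the shapes are made up for illustration, and the usual client setup is assumed): one major batch dimension pairs the leading dimension of input and index.

#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds a batched index_select: input is [2, 3, 4], index is [2, 2].
// Selecting along dim 1 with batch_dims = 1 yields shape [2, 2, 4]:
// the batch dim survives, dim 1 is replaced by the index's non-batch dim.
void BuildBatchedIndexSelect(xla::XlaBuilder* builder) {
  xla::XlaOp input = xla::Parameter(
      builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3, 4}), "input");
  xla::XlaOp index = xla::Parameter(
      builder, 1, xla::ShapeUtil::MakeShape(xla::S32, {2, 2}), "index");
  xla::TorchIndexSelect(input, index, /*dim=*/1, /*batch_dims=*/1);
}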

tensorflow/compiler/xla/client/lib/slicing_test.cc

@@ -146,6 +146,7 @@ XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) {
                        0, "input", &builder, &input);
   auto index_data =
       CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &index);
+
   TorchIndexSelect(input, index, 1);
 
   ComputeAndCompareR2<float>(
@@ -153,5 +154,23 @@ XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) {
       {input_data.get(), index_data.get()});
 }
 
+XLA_TEST_F(SlicingTest, BatchTorchIndexSelectOn0) {
+  xla::XlaBuilder builder(TestName());
+
+  xla::XlaOp input, index;
+  auto input_data =
+      CreateR3Parameter<int>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
+                              {{3, 2, 1, 0}, {7, 6, 5, 4}, {11, 10, 9, 8}}},
+                             0, "input", &builder, &input);
+  auto index_data =
+      CreateR2Parameter<int>({{0, 2}, {1, 2}}, 1, "index", &builder, &index);
+  TorchIndexSelect(input, index, 1, 1);
+
+  ComputeAndCompareR3<int>(
+      &builder,
+      {{{0, 1, 2, 3}, {8, 9, 10, 11}}, {{7, 6, 5, 4}, {11, 10, 9, 8}}},
+      {input_data.get(), index_data.get()});
+}
+
 }  // namespace
 }  // namespace xla