[XLA:CLIENT] Add batch dimensions to TorchIndexSelect
PiperOrigin-RevId: 248150697
This commit is contained in:
parent
78830556b9
commit
ae6df76700
@ -296,6 +296,7 @@ cc_library(
|
||||
hdrs = ["slicing.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
|
@ -14,7 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -161,22 +163,43 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) {
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim) {
|
||||
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) {
|
||||
XlaBuilder* builder = input.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
|
||||
TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
|
||||
if (dim < batch_dims) {
|
||||
return InvalidArgument(
|
||||
"Gather dim must be greater than or equal to the number of batch "
|
||||
"dims");
|
||||
}
|
||||
std::vector<int64> slice_sizes = input_shape.dimensions();
|
||||
slice_sizes[dim] = 1;
|
||||
GatherDimensionNumbers gather_dnums;
|
||||
for (int64 i = 0; i < input_shape.rank(); ++i) {
|
||||
if (i != dim) {
|
||||
gather_dnums.add_offset_dims(i);
|
||||
}
|
||||
}
|
||||
gather_dnums.set_index_vector_dim(index_shape.rank());
|
||||
gather_dnums.add_collapsed_slice_dims(dim);
|
||||
gather_dnums.add_start_index_map(dim);
|
||||
if (batch_dims > 0) {
|
||||
ShapeUtil::AppendMajorDimension(1, &index_shape);
|
||||
std::vector<XlaOp> to_concat;
|
||||
to_concat.reserve(batch_dims + 1);
|
||||
for (int64 batch_dim = 0; batch_dim < batch_dims; ++batch_dim) {
|
||||
to_concat.push_back(Iota(builder, index_shape, batch_dim));
|
||||
}
|
||||
to_concat.push_back(Reshape(index, index_shape.dimensions()));
|
||||
index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim());
|
||||
}
|
||||
for (int64 i = 0; i < input_shape.rank(); ++i) {
|
||||
if (i < batch_dims || i == dim) {
|
||||
slice_sizes[i] = 1;
|
||||
gather_dnums.add_collapsed_slice_dims(i);
|
||||
gather_dnums.add_start_index_map(i);
|
||||
} else {
|
||||
if (i < dim) {
|
||||
gather_dnums.add_offset_dims(i);
|
||||
} else {
|
||||
gather_dnums.add_offset_dims(i + gather_dnums.index_vector_dim() -
|
||||
(1 + batch_dims));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Gather(input, index, gather_dnums, slice_sizes);
|
||||
});
|
||||
}
|
||||
|
@ -63,7 +63,11 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim);
|
||||
// The returned tensor has the same number of dimensions as the original tensor
|
||||
// (input). The dimth dimension has the same size as the length of index; other
|
||||
// dimensions have the same size as in the original tensor.
|
||||
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim);
|
||||
//
|
||||
// This operation supports 0 or more major batch dimensions that act like a
|
||||
// multidimensional loop over both the input and the index.
|
||||
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim,
|
||||
int64 batch_dims = 0);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
|
@ -146,6 +146,7 @@ XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) {
|
||||
0, "input", &builder, &input);
|
||||
auto index_data =
|
||||
CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &index);
|
||||
|
||||
TorchIndexSelect(input, index, 1);
|
||||
|
||||
ComputeAndCompareR2<float>(
|
||||
@ -153,5 +154,23 @@ XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) {
|
||||
{input_data.get(), index_data.get()});
|
||||
}
|
||||
|
||||
XLA_TEST_F(SlicingTest, BatchTorchIndexSelectOn0) {
|
||||
xla::XlaBuilder builder(TestName());
|
||||
|
||||
xla::XlaOp input, index;
|
||||
auto input_data =
|
||||
CreateR3Parameter<int>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
|
||||
{{3, 2, 1, 0}, {7, 6, 5, 4}, {11, 10, 9, 8}}},
|
||||
0, "input", &builder, &input);
|
||||
auto index_data =
|
||||
CreateR2Parameter<int>({{0, 2}, {1, 2}}, 1, "index", &builder, &index);
|
||||
TorchIndexSelect(input, index, 1, 1);
|
||||
|
||||
ComputeAndCompareR3<int>(
|
||||
&builder,
|
||||
{{{0, 1, 2, 3}, {8, 9, 10, 11}}, {{7, 6, 5, 4}, {11, 10, 9, 8}}},
|
||||
{input_data.get(), index_data.get()});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user