Support TopK op to split sorting into smaller sorts and a big final sort.

PiperOrigin-RevId: 294729385
Change-Id: I2a757f3e66ec37b4e004c41e149d8e1afc7b63d1
This commit is contained in:
A. Unique TensorFlower 2020-02-12 12:35:47 -08:00 committed by TensorFlower Gardener
parent 0d204befa9
commit eca81d2593
3 changed files with 123 additions and 0 deletions

View File

@ -56,4 +56,64 @@ XlaOp TopK(XlaOp input, int64 k) {
});
}
// Computes the top `k` values of `input` along its last dimension, together
// with their S32 indices, by sorting the last dimension one partition at a
// time instead of in a single large sort.
//
// The last dimension is cut into at most `num_partitions` contiguous chunks of
// ceil(last_dim_size / num_partitions) elements. Each chunk is concatenated
// with the running top-k carried over from the previous chunks, stably sorted
// with a greater-than comparator, and truncated to `k` elements again.
//
// Returns a tuple (values, indices). Falls back to TopK(input, k) when `k` is
// at least the per-partition size. Reports an InvalidArgument error through
// the builder when `input` is a scalar, when `num_partitions` is not positive,
// or when `k` is negative.
XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) {
  XlaBuilder* const builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    // A scalar has no last dimension to sort along; without this check
    // `last_dim` below would be -1.
    if (input_shape.rank() == 0) {
      return InvalidArgument(
          "TopKWithPartitions requires an input of rank >= 1.");
    }
    // A non-positive partition count would divide by zero in CeilOfRatio.
    if (num_partitions <= 0) {
      return InvalidArgument(
          "TopKWithPartitions requires num_partitions > 0, got %d.",
          num_partitions);
    }
    // A negative k would be used as a slice limit further down.
    if (k < 0) {
      return InvalidArgument("TopKWithPartitions requires k >= 0, got %d.", k);
    }
    int last_dim = input_shape.dimensions_size() - 1;
    // Calculate per partition size.
    auto input_dims = input_shape.dimensions();
    int64 last_dim_size = input_shape.dimensions(last_dim);
    const int64 per_partition_size = CeilOfRatio(last_dim_size, num_partitions);
    // Do normal TopK when per partition size is smaller than or equal to k.
    if (k >= per_partition_size) {
      return TopK(input, k);
    }
    // Index tensor with the same dimensions as the input, counting along the
    // last dimension.
    Shape iota_shape =
        ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
    XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
    for (int64 i = 0; i < input_shape.rank(); ++i) {
      if (input_shape.is_dynamic_dimension(i)) {
        // Propagate dynamic dimension from inputs to iota.
        iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
      }
    }
    // Running top-k values/indices, carried from one partition to the next.
    XlaOp values, indices;
    for (int64 partition = 0; partition < num_partitions; ++partition) {
      const int64 partition_start = partition * per_partition_size;
      // Because the partition size is rounded up, trailing partitions can
      // start at or past the end of the last dimension (e.g. size 10 with 7
      // partitions gives partitions of 2, so only 5 are populated). They hold
      // no elements, and slicing them would request start > limit, so stop.
      if (partition_start >= last_dim_size) {
        break;
      }
      std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
      std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
      std::vector<int64> strides(input_shape.dimensions_size(), 1);
      start_indices[last_dim] = partition_start;
      limit_indices[last_dim] =
          std::min(partition_start + per_partition_size, last_dim_size);
      // Slice value and indices for this partition.
      XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
      XlaOp sliced_indices =
          Slice(iota_s32, start_indices, limit_indices, strides);
      // Concat with previous results. The carried-over entries go first, so
      // the stable sort keeps them ahead of equal values from this partition.
      if (partition > 0) {
        sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
        sliced_indices =
            ConcatInDim(builder, {indices, sliced_indices}, last_dim);
      }
      // Sort this slice
      XlaOp sort_result =
          Sort({sliced_input, sliced_indices},
               CreateScalarGtComputation({input_shape.element_type(), S32},
                                         sliced_indices.builder()),
               last_dim, /*is_stable=*/true);
      // Slice topk.
      start_indices[last_dim] = 0;
      limit_indices[last_dim] = k;
      values = Slice(GetTupleElement(sort_result, 0), start_indices,
                     limit_indices, strides);
      indices = Slice(GetTupleElement(sort_result, 1), start_indices,
                      limit_indices, strides);
    }
    return Tuple(builder, {values, indices});
  });
}
} // namespace xla

View File

@ -25,6 +25,10 @@ namespace xla {
// Returns a tuple composed of the top `k` values and corresponding indices in
// `input`. Output values are in descending order, from largest to smallest.
XlaOp TopK(XlaOp input, int64 k);
// Variant of TopK that splits the sort over the last dimension of `input`
// into smaller sorts, one per partition.
// Returns a tuple composed of the top `k` values and corresponding indices in
// `input`. Output values are in descending order, from largest to smallest.
// `num_partitions` is the number of chunks the last dimension is divided into;
// when `k` is at least the resulting per-partition size (rounded up), the
// result is computed with a plain TopK(input, k) instead.
XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions = 1);
} // namespace xla

View File

@ -76,5 +76,64 @@ XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) {
ComputeAndCompareR1<int>(&builder, {2, 3, 0, 1, 4}, {a_data.get()});
}
// Top-3 values of 8 ascending floats using 2 partitions. The per-partition
// size is 4 (> k), so the partitioned sort path is taken.
XLA_TEST_F(SortingTest, TopK3From8Values2Partitions) {
  XlaBuilder builder(TestName());
  const XlaOp ascending =
      ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
  const XlaOp top_k =
      xla::TopKWithPartitions(ascending, 3, /*num_partitions=*/2);
  // Tuple element 0 carries the values.
  xla::GetTupleElement(top_k, 0);
  ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
}
// Top-3 indices of 8 descending floats using 2 partitions. The per-partition
// size is 4 (> k), so the partitioned sort path is taken.
XLA_TEST_F(SortingTest, TopK3From8Indices2Partitions) {
  XlaBuilder builder(TestName());
  const XlaOp descending =
      ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
  const XlaOp top_k =
      xla::TopKWithPartitions(descending, 3, /*num_partitions=*/2);
  // Tuple element 1 carries the indices.
  xla::GetTupleElement(top_k, 1);
  ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
}
// Top-3 values of 8 ascending floats using 3 partitions. The per-partition
// size is ceil(8 / 3) = 3, which equals k, so TopKWithPartitions takes its
// plain-TopK fallback here.
XLA_TEST_F(SortingTest, TopK3From8Values3Partitions) {
  XlaBuilder builder(TestName());
  const XlaOp ascending =
      ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
  const XlaOp top_k =
      xla::TopKWithPartitions(ascending, 3, /*num_partitions=*/3);
  // Tuple element 0 carries the values.
  xla::GetTupleElement(top_k, 0);
  ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
}
// Top-3 indices of 8 descending floats using 3 partitions. The per-partition
// size is ceil(8 / 3) = 3, which equals k, so TopKWithPartitions takes its
// plain-TopK fallback here.
XLA_TEST_F(SortingTest, TopK3From8Indices3Partitions) {
  XlaBuilder builder(TestName());
  const XlaOp descending =
      ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
  const XlaOp top_k =
      xla::TopKWithPartitions(descending, 3, /*num_partitions=*/3);
  // Tuple element 1 carries the indices.
  xla::GetTupleElement(top_k, 1);
  ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
}
// Top-3 values of 8 ascending floats using 5 partitions. The per-partition
// size is ceil(8 / 5) = 2, which is below k, so TopKWithPartitions takes its
// plain-TopK fallback here.
XLA_TEST_F(SortingTest, TopK3From8Values5Partitions) {
  XlaBuilder builder(TestName());
  const XlaOp ascending =
      ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
  const XlaOp top_k =
      xla::TopKWithPartitions(ascending, 3, /*num_partitions=*/5);
  // Tuple element 0 carries the values.
  xla::GetTupleElement(top_k, 0);
  ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
}
// Top-3 indices of 8 descending floats using 5 partitions. The per-partition
// size is ceil(8 / 5) = 2, which is below k, so TopKWithPartitions takes its
// plain-TopK fallback here.
XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) {
  XlaBuilder builder(TestName());
  const XlaOp descending =
      ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
  const XlaOp top_k =
      xla::TopKWithPartitions(descending, 3, /*num_partitions=*/5);
  // Tuple element 1 carries the indices.
  xla::GetTupleElement(top_k, 1);
  ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
}
// Top-3 indices of a 5-element parameter containing duplicate values, using 2
// partitions. The per-partition size is ceil(5 / 2) = 3, which equals k, so
// TopKWithPartitions takes its plain-TopK fallback here.
XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates2Partitions) {
  XlaBuilder builder(TestName());
  XlaOp operand;
  auto operand_data =
      CreateR1Parameter<int>({1, 1, 2, 2, 1}, 0, "a", &builder, &operand);
  const XlaOp top_k =
      xla::TopKWithPartitions(operand, 3, /*num_partitions=*/2);
  // Tuple element 1 carries the indices; the expectation lists both 2s
  // (indices 2 and 3) followed by the first 1 (index 0).
  xla::GetTupleElement(top_k, 1);
  ComputeAndCompareR1<int>(&builder, {2, 3, 0}, {operand_data.get()});
}
} // namespace
} // namespace xla