Support TopK op to split sorting into smaller sorts and a big final sort.
PiperOrigin-RevId: 294729385 Change-Id: I2a757f3e66ec37b4e004c41e149d8e1afc7b63d1
This commit is contained in:
parent
0d204befa9
commit
eca81d2593
@ -56,4 +56,64 @@ XlaOp TopK(XlaOp input, int64 k) {
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) {
|
||||
XlaBuilder* const builder = input.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
|
||||
int last_dim = input_shape.dimensions_size() - 1;
|
||||
// Calculate per partition size.
|
||||
auto input_dims = input_shape.dimensions();
|
||||
int64 last_dim_size = input_shape.dimensions(last_dim);
|
||||
const int64 per_partition_size = CeilOfRatio(last_dim_size, num_partitions);
|
||||
// Do normal TopK when per partition size is smaller than or equal to k.
|
||||
if (k >= per_partition_size) {
|
||||
return TopK(input, k);
|
||||
}
|
||||
|
||||
Shape iota_shape =
|
||||
ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
|
||||
XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
|
||||
for (int64 i = 0; i < input_shape.rank(); ++i) {
|
||||
if (input_shape.is_dynamic_dimension(i)) {
|
||||
// Propagate dynamic dimension from inputs to iota.
|
||||
iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
|
||||
}
|
||||
}
|
||||
|
||||
XlaOp values, indices;
|
||||
for (int64 partition = 0; partition < num_partitions; partition++) {
|
||||
std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
|
||||
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
|
||||
std::vector<int64> strides(input_shape.dimensions_size(), 1);
|
||||
start_indices[last_dim] = partition * per_partition_size;
|
||||
limit_indices[last_dim] =
|
||||
std::min((partition + 1) * per_partition_size, last_dim_size);
|
||||
// Slice value and indices for this partition..
|
||||
XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
|
||||
XlaOp sliced_indices =
|
||||
Slice(iota_s32, start_indices, limit_indices, strides);
|
||||
// Concat with previous results.
|
||||
if (partition > 0) {
|
||||
sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
|
||||
sliced_indices =
|
||||
ConcatInDim(builder, {indices, sliced_indices}, last_dim);
|
||||
}
|
||||
// Sort this slice
|
||||
XlaOp sort_result =
|
||||
Sort({sliced_input, sliced_indices},
|
||||
CreateScalarGtComputation({input_shape.element_type(), S32},
|
||||
sliced_indices.builder()),
|
||||
last_dim, /*is_stable=*/true);
|
||||
// Slice topk.
|
||||
start_indices[last_dim] = 0;
|
||||
limit_indices[last_dim] = k;
|
||||
values = Slice(GetTupleElement(sort_result, 0), start_indices,
|
||||
limit_indices, strides);
|
||||
indices = Slice(GetTupleElement(sort_result, 1), start_indices,
|
||||
limit_indices, strides);
|
||||
}
|
||||
return Tuple(builder, {values, indices});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
@ -25,6 +25,10 @@ namespace xla {
|
||||
// Returns a tuple composed of the top `k` values and corresponding indices in
|
||||
// `input`. Output values are in descending order, from largest to smallest.
|
||||
XlaOp TopK(XlaOp input, int64 k);
|
||||
// Split sort in TopK into smaller sorts.
|
||||
// Returns a tuple composed of the top `k` values and corresponding indices in
|
||||
// `input`. Output values are in descending order, from largest to smallest.
|
||||
XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions = 1);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
|
||||
@ -76,5 +76,64 @@ XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) {
|
||||
ComputeAndCompareR1<int>(&builder, {2, 3, 0, 1, 4}, {a_data.get()});
|
||||
}
|
||||
|
||||
XLA_TEST_F(SortingTest, TopK3From8Values2Partitions) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x =
|
||||
ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
|
||||
xla::GetTupleElement(xla::TopKWithPartitions(x, 3, /*num_partitions=*/2), 0);
|
||||
ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(SortingTest, TopK3From8Indices2Partitions) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x_rev =
|
||||
ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
|
||||
xla::GetTupleElement(xla::TopKWithPartitions(x_rev, 3, /*num_partitions=*/2),
|
||||
1);
|
||||
ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(SortingTest, TopK3From8Values3Partitions) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x =
|
||||
ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
|
||||
xla::GetTupleElement(xla::TopKWithPartitions(x, 3, /*num_partitions=*/3), 0);
|
||||
ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(SortingTest, TopK3From8Indices3Partitions) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x_rev =
|
||||
ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
|
||||
xla::GetTupleElement(xla::TopKWithPartitions(x_rev, 3, /*num_partitions=*/3),
|
||||
1);
|
||||
ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(SortingTest, TopK3From8Values5Partitions) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x =
|
||||
ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
|
||||
xla::GetTupleElement(xla::TopKWithPartitions(x, 3, /*num_partitions=*/5), 0);
|
||||
ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x_rev =
|
||||
ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
|
||||
xla::GetTupleElement(xla::TopKWithPartitions(x_rev, 3, /*num_partitions=*/5),
|
||||
1);
|
||||
ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates2Partitions) {
|
||||
XlaBuilder builder(TestName());
|
||||
XlaOp a;
|
||||
auto a_data = CreateR1Parameter<int>({1, 1, 2, 2, 1}, 0, "a", &builder, &a);
|
||||
xla::GetTupleElement(xla::TopKWithPartitions(a, 3, /*num_partitions=*/2), 1);
|
||||
ComputeAndCompareR1<int>(&builder, {2, 3, 0}, {a_data.get()});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user