diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc
index a9590be52a7..750237c2000 100644
--- a/tensorflow/compiler/xla/client/lib/sorting.cc
+++ b/tensorflow/compiler/xla/client/lib/sorting.cc
@@ -56,4 +56,82 @@ XlaOp TopK(XlaOp input, int64 k) {
   });
 }
 
+// Partitioned variant of TopK. The last dimension is cut into partitions of
+// CeilOfRatio(size, num_partitions) elements; each partition is concatenated
+// with the running top `k` of the earlier ones, sorted, and cut back to `k`.
+XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) {
+  XlaBuilder* const builder = input.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
+    const int64 rank = input_shape.dimensions_size();
+    // A scalar has no dimension to split, and a partition count below one
+    // cannot serve as a divisor, so leave both cases to the plain TopK.
+    if (rank == 0 || num_partitions < 1) {
+      return TopK(input, k);
+    }
+    const int64 last_dim = rank - 1;
+    // Calculate per partition size.
+    auto input_dims = input_shape.dimensions();
+    const int64 last_dim_size = input_shape.dimensions(last_dim);
+    const int64 per_partition_size = CeilOfRatio(last_dim_size, num_partitions);
+    // Do normal TopK when per partition size is smaller than or equal to k.
+    if (k >= per_partition_size) {
+      return TopK(input, k);
+    }
+
+    // Indices along the last dimension; they are sorted together with the
+    // values so that every kept value still carries its position in `input`.
+    Shape iota_shape =
+        ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
+    XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
+    for (int64 i = 0; i < input_shape.rank(); ++i) {
+      if (input_shape.is_dynamic_dimension(i)) {
+        // Propagate dynamic dimension from inputs to iota.
+        iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i);
+      }
+    }
+
+    const std::vector<int64> strides(rank, 1);
+    XlaOp values, indices;
+    for (int64 partition = 0; partition < num_partitions; partition++) {
+      const int64 partition_start = partition * per_partition_size;
+      // Rounding the partition size up can leave trailing partitions that
+      // start at or past the end of the last dimension. They hold no elements
+      // and would hand Slice a start index beyond its limit, so stop here.
+      if (partition > 0 && partition_start >= last_dim_size) {
+        break;
+      }
+      std::vector<int64> start_indices(rank, 0);
+      std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+      start_indices[last_dim] = partition_start;
+      limit_indices[last_dim] =
+          std::min(partition_start + per_partition_size, last_dim_size);
+      // Slice value and indices for this partition.
+      XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
+      XlaOp sliced_indices =
+          Slice(iota_s32, start_indices, limit_indices, strides);
+      // Concat with previous results.
+      if (partition > 0) {
+        sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
+        sliced_indices =
+            ConcatInDim(builder, {indices, sliced_indices}, last_dim);
+      }
+      // Sort this slice.
+      XlaOp sort_result =
+          Sort({sliced_input, sliced_indices},
+               CreateScalarGtComputation({input_shape.element_type(), S32},
+                                         sliced_indices.builder()),
+               last_dim, /*is_stable=*/true);
+      // Slice topk.
+      start_indices[last_dim] = 0;
+      limit_indices[last_dim] = k;
+      values = Slice(GetTupleElement(sort_result, 0), start_indices,
+                     limit_indices, strides);
+      indices = Slice(GetTupleElement(sort_result, 1), start_indices,
+                      limit_indices, strides);
+    }
+    return Tuple(builder, {values, indices});
+  });
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/sorting.h b/tensorflow/compiler/xla/client/lib/sorting.h
index b9dfafdd6f9..699c8afd0de 100644
--- a/tensorflow/compiler/xla/client/lib/sorting.h
+++ b/tensorflow/compiler/xla/client/lib/sorting.h
@@ -25,6 +25,13 @@ namespace xla {
 // Returns a tuple composed of the top `k` values and corresponding indices in
 // `input`. Output values are in descending order, from largest to smallest.
 XlaOp TopK(XlaOp input, int64 k);
+// Splits the sort in TopK into at most `num_partitions` smaller sorts over
+// the last dimension of `input`, carrying the running top `k` from one
+// partition to the next. Uses a single TopK when `k` is not smaller than the
+// per-partition size.
+// Returns a tuple composed of the top `k` values and corresponding indices in
+// `input`. Output values are in descending order, from largest to smallest.
+XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions = 1);
 
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc
index 3bba84d90d4..e01f6faf59e 100644
--- a/tensorflow/compiler/xla/client/lib/sorting_test.cc
+++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc
@@ -76,5 +76,81 @@ XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) {
   ComputeAndCompareR1<int>(&builder, {2, 3, 0, 1, 4}, {a_data.get()});
 }
 
+// Per-partition size is 4 > k, so the two partitions are sorted in turn.
+XLA_TEST_F(SortingTest, TopK3From8Values2Partitions) {
+  XlaBuilder builder(TestName());
+  auto x =
+      ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+  xla::GetTupleElement(xla::TopKWithPartitions(x, 3, /*num_partitions=*/2), 0);
+  ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
+}
+
+// Same partitioning as the test above; checks the returned indices.
+XLA_TEST_F(SortingTest, TopK3From8Indices2Partitions) {
+  XlaBuilder builder(TestName());
+  auto x_rev =
+      ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
+  xla::GetTupleElement(xla::TopKWithPartitions(x_rev, 3, /*num_partitions=*/2),
+                       1);
+  ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
+}
+
+// Per-partition size is CeilOfRatio(8, 3) == 3 == k: falls back to TopK.
+XLA_TEST_F(SortingTest, TopK3From8Values3Partitions) {
+  XlaBuilder builder(TestName());
+  auto x =
+      ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+  xla::GetTupleElement(xla::TopKWithPartitions(x, 3, /*num_partitions=*/3), 0);
+  ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
+}
+
+// Fallback path again (per-partition size 3 == k); checks the indices.
+XLA_TEST_F(SortingTest, TopK3From8Indices3Partitions) {
+  XlaBuilder builder(TestName());
+  auto x_rev =
+      ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
+  xla::GetTupleElement(xla::TopKWithPartitions(x_rev, 3, /*num_partitions=*/3),
+                       1);
+  ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
+}
+
+// Per-partition size is CeilOfRatio(8, 5) == 2 < k: falls back to TopK.
+XLA_TEST_F(SortingTest, TopK3From8Values5Partitions) {
+  XlaBuilder builder(TestName());
+  auto x =
+      ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+  xla::GetTupleElement(xla::TopKWithPartitions(x, 3, /*num_partitions=*/5), 0);
+  ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
+}
+
+// Fallback path again (per-partition size 2 < k); checks the indices.
+XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) {
+  XlaBuilder builder(TestName());
+  auto x_rev =
+      ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
+  xla::GetTupleElement(xla::TopKWithPartitions(x_rev, 3, /*num_partitions=*/5),
+                       1);
+  ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
+}
+
+// Per-partition size is CeilOfRatio(5, 2) == 3 == k: falls back to TopK.
+XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates2Partitions) {
+  XlaBuilder builder(TestName());
+  XlaOp a;
+  auto a_data = CreateR1Parameter<int>({1, 1, 2, 2, 1}, 0, "a", &builder, &a);
+  xla::GetTupleElement(xla::TopKWithPartitions(a, 3, /*num_partitions=*/2), 1);
+  ComputeAndCompareR1<int>(&builder, {2, 3, 0}, {a_data.get()});
+}
+
+// Here k is smaller than the per-partition size (3) and the last partition is
+// shorter than the others, so the partitions are sorted one after another.
+XLA_TEST_F(SortingTest, TopK2From8Values3Partitions) {
+  XlaBuilder builder(TestName());
+  auto x =
+      ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+  xla::GetTupleElement(xla::TopKWithPartitions(x, 2, /*num_partitions=*/3), 0);
+  ComputeAndCompareR1<float>(&builder, {7.0, 6.0}, {});
+}
+
 }  // namespace
 }  // namespace xla