From 37a8445ebd2c09c31d72661ec4c08057951bc93e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Aug 2020 22:27:50 -0700 Subject: [PATCH] Call TopKWithPartition when input is large. PiperOrigin-RevId: 328882175 Change-Id: I0c53e957d4e8c83e5c8598cc52748b6a9219088d --- tensorflow/compiler/xla/client/lib/BUILD | 3 + tensorflow/compiler/xla/client/lib/sorting.cc | 93 +++++++++++++++---- .../compiler/xla/client/lib/sorting_test.cc | 13 +++ 3 files changed, 90 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index f0b4e5e6c79..b2a18492c57 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -346,6 +346,9 @@ cc_library( hdrs = ["sorting.h"], deps = [ ":comparators", + ":constants", + ":loops", + ":slicing", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index 750237c2000..5a7a70192d1 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -16,6 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -27,6 +30,19 @@ XlaOp TopK(XlaOp input, int64 k) { return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; + int64 last_dim_size = input_shape.dimensions(last_dim); + // TODO(b/165839365): tune these constants for better performance. + int64 kPerPartitionSize = 8192; // 2^13 + int64 kLastDimSizeThreshold = 524288; // 2^19 + int64 kMinNumPartitions = 8; + if ((k > 0) && (k < kPerPartitionSize) && (kPerPartitionSize / k > 2) && + last_dim_size >= kLastDimSizeThreshold) { + int64 num_partitions = + CeilOfRatio(last_dim_size - k, kPerPartitionSize - k); + if (num_partitions >= kMinNumPartitions) { + return TopKWithPartitions(input, k, num_partitions); + } + } Shape iota_shape = ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); @@ -80,30 +96,35 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) { } } - XlaOp values, indices; - for (int64 partition = 0; partition < num_partitions; partition++) { - std::vector start_indices(input_shape.dimensions_size(), 0); - std::vector limit_indices(input_dims.begin(), input_dims.end()); - std::vector strides(input_shape.dimensions_size(), 1); - start_indices[last_dim] = partition * per_partition_size; - limit_indices[last_dim] = - std::min((partition + 1) * per_partition_size, last_dim_size); - // Slice value and indices for this partition.. - XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides); + auto topk_body_fn = + [&](XlaOp partition, absl::Span values_and_indices, + XlaBuilder* builder) -> StatusOr> { + auto values = values_and_indices[0]; + auto indices = values_and_indices[1]; + auto input = values_and_indices[2]; + auto iota_s32 = values_and_indices[3]; + + // Slice value and indices for this partition. + XlaOp start = Mul(Add(partition, ConstantR0(builder, 1)), + ConstantR0(builder, per_partition_size)); + XlaOp sliced_input = + DynamicSliceInMinorDims(input, {start}, {per_partition_size}); XlaOp sliced_indices = - Slice(iota_s32, start_indices, limit_indices, strides); + DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size}); // Concat with previous results. - if (partition > 0) { - sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim); - sliced_indices = - ConcatInDim(builder, {indices, sliced_indices}, last_dim); - } + sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim); + sliced_indices = + ConcatInDim(builder, {indices, sliced_indices}, last_dim); // Sort this slice XlaOp sort_result = Sort({sliced_input, sliced_indices}, CreateScalarGtComputation({input_shape.element_type(), S32}, sliced_indices.builder()), - last_dim, /*is_stable=*/true); + last_dim, true); + + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + std::vector strides(input_shape.dimensions_size(), 1); // Slice topk. start_indices[last_dim] = 0; limit_indices[last_dim] = k; @@ -111,8 +132,42 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) { limit_indices, strides); indices = Slice(GetTupleElement(sort_result, 1), start_indices, limit_indices, strides); - } - return Tuple(builder, {values, indices}); + return std::vector{values, indices, input, iota_s32}; + }; + + // Get the values and indices for the first topk so that they can + // be passed to the while loop. + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + std::vector strides(input_shape.dimensions_size(), 1); + start_indices[last_dim] = 0; + limit_indices[last_dim] = per_partition_size; + // Slice value and indices for the first partition. + XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides); + XlaOp sliced_indices = + Slice(iota_s32, start_indices, limit_indices, strides); + // Sort this slice + XlaOp sort_result = + Sort({sliced_input, sliced_indices}, + CreateScalarGtComputation({input_shape.element_type(), S32}, + sliced_indices.builder()), + last_dim, /*is_stable=*/true); + + // Slice topk. + start_indices[last_dim] = 0; + limit_indices[last_dim] = k; + XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides); + XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, + limit_indices, strides); + + // Pass the result of the first TopK to the while loop and do + // num_partition - 1 iterations. + TF_ASSIGN_OR_RETURN(auto values_and_indices, + ForEachIndex(num_partitions - 1, S32, topk_body_fn, + {values, indices, input, iota_s32}, + "topk_with_partition", builder)); + return Tuple(builder, {values_and_indices[0], values_and_indices[1]}); }); } diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index e01f6faf59e..e820d5bfe6f 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -118,6 +118,19 @@ XLA_TEST_F(SortingTest, TopK3From8Values5Partitions) { ComputeAndCompareR1(&builder, {7.0, 6.0, 5.0}, {}); } +XLA_TEST_F(SortingTest, DISABLED_TopKLargeInput) { + XlaBuilder builder(TestName()); + Array input({2, 1000000}); + input.FillRandom(1.0f, 2.0f); + auto x = + CreateConstantFromLiteral(LiteralUtil::CreateFromArray(input), &builder); + Array2D expected_array(2, 1000); + expected_array.Fill(2.0f); + xla::GetTupleElement(xla::TopK(x, 1000), 0); + ErrorSpec error_spec(10.0f, 10.0f); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec); +} + XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) { XlaBuilder builder(TestName()); auto x_rev =