Call TopKWithPartition when input is large.

PiperOrigin-RevId: 328882175
Change-Id: I0c53e957d4e8c83e5c8598cc52748b6a9219088d
This commit is contained in:
A. Unique TensorFlower 2020-08-27 22:27:50 -07:00 committed by TensorFlower Gardener
parent 17a43be51d
commit 37a8445ebd
3 changed files with 90 additions and 19 deletions

View File

@ -346,6 +346,9 @@ cc_library(
hdrs = ["sorting.h"],
deps = [
":comparators",
":constants",
":loops",
":slicing",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",

View File

@ -16,6 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
@ -27,6 +30,19 @@ XlaOp TopK(XlaOp input, int64 k) {
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
int last_dim = input_shape.dimensions_size() - 1;
int64 last_dim_size = input_shape.dimensions(last_dim);
// TODO(b/165839365): tune these constants for better performance.
int64 kPerPartitionSize = 8192; // 2^13
int64 kLastDimSizeThreshold = 524288; // 2^19
int64 kMinNumPartitions = 8;
if ((k > 0) && (k < kPerPartitionSize) && (kPerPartitionSize / k > 2) &&
last_dim_size >= kLastDimSizeThreshold) {
int64 num_partitions =
CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
if (num_partitions >= kMinNumPartitions) {
return TopKWithPartitions(input, k, num_partitions);
}
}
Shape iota_shape =
ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
@ -80,30 +96,35 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) {
}
}
XlaOp values, indices;
for (int64 partition = 0; partition < num_partitions; partition++) {
std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
std::vector<int64> strides(input_shape.dimensions_size(), 1);
start_indices[last_dim] = partition * per_partition_size;
limit_indices[last_dim] =
std::min((partition + 1) * per_partition_size, last_dim_size);
// Slice value and indices for this partition..
XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
auto topk_body_fn =
[&](XlaOp partition, absl::Span<const XlaOp> values_and_indices,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
auto values = values_and_indices[0];
auto indices = values_and_indices[1];
auto input = values_and_indices[2];
auto iota_s32 = values_and_indices[3];
// Slice value and indices for this partition.
XlaOp start = Mul(Add(partition, ConstantR0<int32>(builder, 1)),
ConstantR0<int32>(builder, per_partition_size));
XlaOp sliced_input =
DynamicSliceInMinorDims(input, {start}, {per_partition_size});
XlaOp sliced_indices =
Slice(iota_s32, start_indices, limit_indices, strides);
DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size});
// Concat with previous results.
if (partition > 0) {
sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
sliced_indices =
ConcatInDim(builder, {indices, sliced_indices}, last_dim);
}
sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
sliced_indices =
ConcatInDim(builder, {indices, sliced_indices}, last_dim);
// Sort this slice
XlaOp sort_result =
Sort({sliced_input, sliced_indices},
CreateScalarGtComputation({input_shape.element_type(), S32},
sliced_indices.builder()),
last_dim, /*is_stable=*/true);
last_dim, true);
std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
std::vector<int64> strides(input_shape.dimensions_size(), 1);
// Slice topk.
start_indices[last_dim] = 0;
limit_indices[last_dim] = k;
@ -111,8 +132,42 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) {
limit_indices, strides);
indices = Slice(GetTupleElement(sort_result, 1), start_indices,
limit_indices, strides);
}
return Tuple(builder, {values, indices});
return std::vector<XlaOp>{values, indices, input, iota_s32};
};
// Get the values and indices for the first topk so that they can
// be passed to the while loop.
std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
std::vector<int64> strides(input_shape.dimensions_size(), 1);
start_indices[last_dim] = 0;
limit_indices[last_dim] = per_partition_size;
// Slice value and indices for the first partition.
XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
XlaOp sliced_indices =
Slice(iota_s32, start_indices, limit_indices, strides);
// Sort this slice
XlaOp sort_result =
Sort({sliced_input, sliced_indices},
CreateScalarGtComputation({input_shape.element_type(), S32},
sliced_indices.builder()),
last_dim, /*is_stable=*/true);
// Slice topk.
start_indices[last_dim] = 0;
limit_indices[last_dim] = k;
XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
limit_indices, strides);
XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
limit_indices, strides);
// Pass the result of the first TopK to the while loop and do
// num_partition - 1 iterations.
TF_ASSIGN_OR_RETURN(auto values_and_indices,
ForEachIndex(num_partitions - 1, S32, topk_body_fn,
{values, indices, input, iota_s32},
"topk_with_partition", builder));
return Tuple(builder, {values_and_indices[0], values_and_indices[1]});
});
}

View File

@ -118,6 +118,19 @@ XLA_TEST_F(SortingTest, TopK3From8Values5Partitions) {
ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
}
XLA_TEST_F(SortingTest, DISABLED_TopKLargeInput) {
XlaBuilder builder(TestName());
Array<float> input({2, 1000000});
input.FillRandom(1.0f, 2.0f);
auto x =
CreateConstantFromLiteral(LiteralUtil::CreateFromArray(input), &builder);
Array2D<float> expected_array(2, 1000);
expected_array.Fill(2.0f);
xla::GetTupleElement(xla::TopK(x, 1000), 0);
ErrorSpec error_spec(10.0f, 10.0f);
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec);
}
XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) {
XlaBuilder builder(TestName());
auto x_rev =