Call TopKWithPartitions when input is large.
PiperOrigin-RevId: 328882175 Change-Id: I0c53e957d4e8c83e5c8598cc52748b6a9219088d
This commit is contained in:
parent
17a43be51d
commit
37a8445ebd
@ -346,6 +346,9 @@ cc_library(
|
||||
hdrs = ["sorting.h"],
|
||||
deps = [
|
||||
":comparators",
|
||||
":constants",
|
||||
":loops",
|
||||
":slicing",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/sorting.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/comparators.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/loops.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -27,6 +30,19 @@ XlaOp TopK(XlaOp input, int64 k) {
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
|
||||
int last_dim = input_shape.dimensions_size() - 1;
|
||||
int64 last_dim_size = input_shape.dimensions(last_dim);
|
||||
// TODO(b/165839365): tune these constants for better performance.
|
||||
int64 kPerPartitionSize = 8192; // 2^13
|
||||
int64 kLastDimSizeThreshold = 524288; // 2^19
|
||||
int64 kMinNumPartitions = 8;
|
||||
if ((k > 0) && (k < kPerPartitionSize) && (kPerPartitionSize / k > 2) &&
|
||||
last_dim_size >= kLastDimSizeThreshold) {
|
||||
int64 num_partitions =
|
||||
CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
|
||||
if (num_partitions >= kMinNumPartitions) {
|
||||
return TopKWithPartitions(input, k, num_partitions);
|
||||
}
|
||||
}
|
||||
|
||||
Shape iota_shape =
|
||||
ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
|
||||
@ -80,30 +96,35 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) {
|
||||
}
|
||||
}
|
||||
|
||||
XlaOp values, indices;
|
||||
for (int64 partition = 0; partition < num_partitions; partition++) {
|
||||
std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
|
||||
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
|
||||
std::vector<int64> strides(input_shape.dimensions_size(), 1);
|
||||
start_indices[last_dim] = partition * per_partition_size;
|
||||
limit_indices[last_dim] =
|
||||
std::min((partition + 1) * per_partition_size, last_dim_size);
|
||||
// Slice value and indices for this partition.
|
||||
XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
|
||||
auto topk_body_fn =
|
||||
[&](XlaOp partition, absl::Span<const XlaOp> values_and_indices,
|
||||
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
|
||||
auto values = values_and_indices[0];
|
||||
auto indices = values_and_indices[1];
|
||||
auto input = values_and_indices[2];
|
||||
auto iota_s32 = values_and_indices[3];
|
||||
|
||||
// Slice value and indices for this partition.
|
||||
XlaOp start = Mul(Add(partition, ConstantR0<int32>(builder, 1)),
|
||||
ConstantR0<int32>(builder, per_partition_size));
|
||||
XlaOp sliced_input =
|
||||
DynamicSliceInMinorDims(input, {start}, {per_partition_size});
|
||||
XlaOp sliced_indices =
|
||||
Slice(iota_s32, start_indices, limit_indices, strides);
|
||||
DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size});
|
||||
// Concat with previous results.
|
||||
if (partition > 0) {
|
||||
sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
|
||||
sliced_indices =
|
||||
ConcatInDim(builder, {indices, sliced_indices}, last_dim);
|
||||
}
|
||||
sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
|
||||
sliced_indices =
|
||||
ConcatInDim(builder, {indices, sliced_indices}, last_dim);
|
||||
// Sort this slice
|
||||
XlaOp sort_result =
|
||||
Sort({sliced_input, sliced_indices},
|
||||
CreateScalarGtComputation({input_shape.element_type(), S32},
|
||||
sliced_indices.builder()),
|
||||
last_dim, /*is_stable=*/true);
|
||||
last_dim, true);
|
||||
|
||||
std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
|
||||
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
|
||||
std::vector<int64> strides(input_shape.dimensions_size(), 1);
|
||||
// Slice topk.
|
||||
start_indices[last_dim] = 0;
|
||||
limit_indices[last_dim] = k;
|
||||
@ -111,8 +132,42 @@ XlaOp TopKWithPartitions(XlaOp input, int64 k, int64 num_partitions) {
|
||||
limit_indices, strides);
|
||||
indices = Slice(GetTupleElement(sort_result, 1), start_indices,
|
||||
limit_indices, strides);
|
||||
}
|
||||
return Tuple(builder, {values, indices});
|
||||
return std::vector<XlaOp>{values, indices, input, iota_s32};
|
||||
};
|
||||
|
||||
// Get the values and indices for the first topk so that they can
|
||||
// be passed to the while loop.
|
||||
std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
|
||||
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
|
||||
std::vector<int64> strides(input_shape.dimensions_size(), 1);
|
||||
start_indices[last_dim] = 0;
|
||||
limit_indices[last_dim] = per_partition_size;
|
||||
// Slice value and indices for the first partition.
|
||||
XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
|
||||
XlaOp sliced_indices =
|
||||
Slice(iota_s32, start_indices, limit_indices, strides);
|
||||
// Sort this slice
|
||||
XlaOp sort_result =
|
||||
Sort({sliced_input, sliced_indices},
|
||||
CreateScalarGtComputation({input_shape.element_type(), S32},
|
||||
sliced_indices.builder()),
|
||||
last_dim, /*is_stable=*/true);
|
||||
|
||||
// Slice topk.
|
||||
start_indices[last_dim] = 0;
|
||||
limit_indices[last_dim] = k;
|
||||
XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
|
||||
limit_indices, strides);
|
||||
XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
|
||||
limit_indices, strides);
|
||||
|
||||
// Pass the result of the first TopK to the while loop and do
|
||||
// num_partitions - 1 iterations.
|
||||
TF_ASSIGN_OR_RETURN(auto values_and_indices,
|
||||
ForEachIndex(num_partitions - 1, S32, topk_body_fn,
|
||||
{values, indices, input, iota_s32},
|
||||
"topk_with_partition", builder));
|
||||
return Tuple(builder, {values_and_indices[0], values_and_indices[1]});
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -118,6 +118,19 @@ XLA_TEST_F(SortingTest, TopK3From8Values5Partitions) {
|
||||
ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
|
||||
}
|
||||
|
||||
// Exercises xla::TopK with k = 1000 on a constant built from a 2 x 1000000
// float array populated via FillRandom(1.0f, 2.0f), and compares element 0 of
// the resulting tuple against a 2 x 1000 array filled with 2.0f.
// NOTE(review): the ErrorSpec is (10.0f, 10.0f) and the expected array is a
// constant fill, so the numeric comparison looks deliberately loose; presumably
// the goal is to check that the large-input path builds and runs rather than
// to pin exact values. TODO confirm.
// NOTE(review): the DISABLED_ name prefix presumably keeps this test from
// running by default; verify against the test framework in use.
XLA_TEST_F(SortingTest, DISABLED_TopKLargeInput) {
|
||||
XlaBuilder builder(TestName());
|
||||
Array<float> input({2, 1000000});
|
||||
input.FillRandom(1.0f, 2.0f);
|
||||
auto x =
|
||||
CreateConstantFromLiteral(LiteralUtil::CreateFromArray(input), &builder);
|
||||
Array2D<float> expected_array(2, 1000);
|
||||
|
||||
expected_array.Fill(2.0f);
|
||||
xla::GetTupleElement(xla::TopK(x, 1000), 0);
|
||||
ErrorSpec error_spec(10.0f, 10.0f);
|
||||
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec);
|
||||
}
|
||||
|
||||
XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x_rev =
|
||||
|
Loading…
Reference in New Issue
Block a user