Increase threshold for K to triggered TopKWithPartition.

PiperOrigin-RevId: 330853650
Change-Id: Ia515202a5194b8a9b881beaa4577462dc1cb05c7
This commit is contained in:
A. Unique TensorFlower 2020-09-09 20:14:22 -07:00 committed by TensorFlower Gardener
parent 9645e47535
commit a9104043a8

View File

@ -31,12 +31,13 @@ XlaOp TopK(XlaOp input, int64 k) {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
int last_dim = input_shape.dimensions_size() - 1;
int64 last_dim_size = input_shape.dimensions(last_dim);
// TODO(b/165839365): tune these constants for better performance.
int64 kPerPartitionSize = 8192; // 2^13
int64 kLastDimSizeThreshold = 524288; // 2^19
int64 kMinNumPartitions = 8;
if ((k > 0) && (k < kPerPartitionSize) && (kPerPartitionSize / k > 2) &&
last_dim_size >= kLastDimSizeThreshold) {
// TODO(b/148796364): tune these constants for better performance.
const int64 kPerPartitionSize = 8192; // 2^13
const int64 kLastDimSizeThreshold = 524288; // 2^19
const int64 kMinNumPartitions = 8;
const int64 kMinimalK = 1000;
if ((k >= kMinimalK) && (k < kPerPartitionSize) &&
(kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) {
int64 num_partitions =
CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
if (num_partitions >= kMinNumPartitions) {