Increase threshold for K to triggered TopKWithPartition.
PiperOrigin-RevId: 330853650 Change-Id: Ia515202a5194b8a9b881beaa4577462dc1cb05c7
This commit is contained in:
parent
9645e47535
commit
a9104043a8
@ -31,12 +31,13 @@ XlaOp TopK(XlaOp input, int64 k) {
|
||||
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
|
||||
int last_dim = input_shape.dimensions_size() - 1;
|
||||
int64 last_dim_size = input_shape.dimensions(last_dim);
|
||||
// TODO(b/165839365): tune these constants for better performance.
|
||||
int64 kPerPartitionSize = 8192; // 2^13
|
||||
int64 kLastDimSizeThreshold = 524288; // 2^19
|
||||
int64 kMinNumPartitions = 8;
|
||||
if ((k > 0) && (k < kPerPartitionSize) && (kPerPartitionSize / k > 2) &&
|
||||
last_dim_size >= kLastDimSizeThreshold) {
|
||||
// TODO(b/148796364): tune these constants for better performance.
|
||||
const int64 kPerPartitionSize = 8192; // 2^13
|
||||
const int64 kLastDimSizeThreshold = 524288; // 2^19
|
||||
const int64 kMinNumPartitions = 8;
|
||||
const int64 kMinimalK = 1000;
|
||||
if ((k >= kMinimalK) && (k < kPerPartitionSize) &&
|
||||
(kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) {
|
||||
int64 num_partitions =
|
||||
CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
|
||||
if (num_partitions >= kMinNumPartitions) {
|
||||
|
Loading…
Reference in New Issue
Block a user