Increase threshold for K to triggered TopKWithPartition.
PiperOrigin-RevId: 330853650 Change-Id: Ia515202a5194b8a9b881beaa4577462dc1cb05c7
This commit is contained in:
		
							parent
							
								
									9645e47535
								
							
						
					
					
						commit
						a9104043a8
					
				| @ -31,12 +31,13 @@ XlaOp TopK(XlaOp input, int64 k) { | ||||
|     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); | ||||
|     int last_dim = input_shape.dimensions_size() - 1; | ||||
|     int64 last_dim_size = input_shape.dimensions(last_dim); | ||||
|     // TODO(b/165839365): tune these constants for better performance.
 | ||||
|     int64 kPerPartitionSize = 8192;        // 2^13
 | ||||
|     int64 kLastDimSizeThreshold = 524288;  // 2^19
 | ||||
|     int64 kMinNumPartitions = 8; | ||||
|     if ((k > 0) && (k < kPerPartitionSize) && (kPerPartitionSize / k > 2) && | ||||
|         last_dim_size >= kLastDimSizeThreshold) { | ||||
|     // TODO(b/148796364): tune these constants for better performance.
 | ||||
|     const int64 kPerPartitionSize = 8192;        // 2^13
 | ||||
|     const int64 kLastDimSizeThreshold = 524288;  // 2^19
 | ||||
|     const int64 kMinNumPartitions = 8; | ||||
|     const int64 kMinimalK = 1000; | ||||
|     if ((k >= kMinimalK) && (k < kPerPartitionSize) && | ||||
|         (kPerPartitionSize / k > 2) && last_dim_size >= kLastDimSizeThreshold) { | ||||
|       int64 num_partitions = | ||||
|           CeilOfRatio(last_dim_size - k, kPerPartitionSize - k); | ||||
|       if (num_partitions >= kMinNumPartitions) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user