Fix in_top_k on GPU with empty input.

PiperOrigin-RevId: 303341572
Change-Id: Idb5d93a462aa2bf0a6e5ab0ec9ec3b1aa39e9981
This commit is contained in:
A. Unique TensorFlower 2020-03-27 09:09:42 -07:00 committed by TensorFlower Gardener
parent 7e2f5ac786
commit 6ef1de968e
2 changed files with 12 additions and 1 deletions

View File

@ -100,6 +100,12 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
errors::InvalidArgument( errors::InvalidArgument(
"Number of targets * number of classes must be less than INT_MAX")); "Number of targets * number of classes must be less than INT_MAX"));
if (num_targets == 0 || num_classes == 0) {
// Result is empty, so shortcut the rest of the function to avoid
// launching kernels with empty input.
return;
}
// Temporary storage for a mask computed by `ComputePredictionMaskKernel`. // Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
Tensor predictions_mask; Tensor predictions_mask;
OP_REQUIRES_OK( OP_REQUIRES_OK(

View File

@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class InTopKTest(test.TestCase): class InTopKTest(test.TestCase):
def _validateInTopK(self, predictions, target, k, expected): def _validateInTopK(self, predictions, target, k, expected):
np_ans = np.array(expected) np_ans = np.array(expected, np.bool)
with self.cached_session(use_gpu=True): with self.cached_session(use_gpu=True):
precision = nn_ops.in_top_k(predictions, target, k) precision = nn_ops.in_top_k(predictions, target, k)
out = self.evaluate(precision) out = self.evaluate(precision)
@ -66,6 +66,11 @@ class InTopKTest(test.TestCase):
target = [2, 4] # must return False for invalid target target = [2, 4] # must return False for invalid target
self._validateInTopK(predictions, target, 2, [True, False]) self._validateInTopK(predictions, target, 2, [True, False])
def testEmpty(self):
predictions = np.empty([0, 5])
target = np.empty([0], np.int32)
self._validateInTopK(predictions, target, 2, [])
def testTensorK(self): def testTensorK(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
target = [0, 2] target = [0, 2]