Fix in_top_k on GPU with empty input.
PiperOrigin-RevId: 303341572 Change-Id: Idb5d93a462aa2bf0a6e5ab0ec9ec3b1aa39e9981
This commit is contained in:
parent
7e2f5ac786
commit
6ef1de968e
@ -100,6 +100,12 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
|
||||
errors::InvalidArgument(
|
||||
"Number of targets * number of classes must be less than INT_MAX"));
|
||||
|
||||
if (num_targets == 0 || num_classes == 0) {
|
||||
// Result is empty, so shortcut the rest of the function to avoid
|
||||
// launching kernels with empty input.
|
||||
return;
|
||||
}
|
||||
|
||||
// Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
|
||||
Tensor predictions_mask;
|
||||
OP_REQUIRES_OK(
|
||||
|
@ -28,7 +28,7 @@ from tensorflow.python.platform import test
|
||||
class InTopKTest(test.TestCase):
|
||||
|
||||
def _validateInTopK(self, predictions, target, k, expected):
|
||||
np_ans = np.array(expected)
|
||||
np_ans = np.array(expected, np.bool)
|
||||
with self.cached_session(use_gpu=True):
|
||||
precision = nn_ops.in_top_k(predictions, target, k)
|
||||
out = self.evaluate(precision)
|
||||
@ -66,6 +66,11 @@ class InTopKTest(test.TestCase):
|
||||
target = [2, 4] # must return False for invalid target
|
||||
self._validateInTopK(predictions, target, 2, [True, False])
|
||||
|
||||
def testEmpty(self):
|
||||
predictions = np.empty([0, 5])
|
||||
target = np.empty([0], np.int32)
|
||||
self._validateInTopK(predictions, target, 2, [])
|
||||
|
||||
def testTensorK(self):
|
||||
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
||||
target = [0, 2]
|
||||
|
Loading…
Reference in New Issue
Block a user