Fix in_top_k on GPU with empty input.
PiperOrigin-RevId: 303341572 Change-Id: Idb5d93a462aa2bf0a6e5ab0ec9ec3b1aa39e9981
This commit is contained in:
parent
7e2f5ac786
commit
6ef1de968e
@ -100,6 +100,12 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
|
|||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Number of targets * number of classes must be less than INT_MAX"));
|
"Number of targets * number of classes must be less than INT_MAX"));
|
||||||
|
|
||||||
|
if (num_targets == 0 || num_classes == 0) {
|
||||||
|
// Result is empty, so shortcut the rest of the function to avoid
|
||||||
|
// launching kernels with empty input.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
|
// Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
|
||||||
Tensor predictions_mask;
|
Tensor predictions_mask;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
|
@ -28,7 +28,7 @@ from tensorflow.python.platform import test
|
|||||||
class InTopKTest(test.TestCase):
|
class InTopKTest(test.TestCase):
|
||||||
|
|
||||||
def _validateInTopK(self, predictions, target, k, expected):
|
def _validateInTopK(self, predictions, target, k, expected):
|
||||||
np_ans = np.array(expected)
|
np_ans = np.array(expected, np.bool)
|
||||||
with self.cached_session(use_gpu=True):
|
with self.cached_session(use_gpu=True):
|
||||||
precision = nn_ops.in_top_k(predictions, target, k)
|
precision = nn_ops.in_top_k(predictions, target, k)
|
||||||
out = self.evaluate(precision)
|
out = self.evaluate(precision)
|
||||||
@ -66,6 +66,11 @@ class InTopKTest(test.TestCase):
|
|||||||
target = [2, 4] # must return False for invalid target
|
target = [2, 4] # must return False for invalid target
|
||||||
self._validateInTopK(predictions, target, 2, [True, False])
|
self._validateInTopK(predictions, target, 2, [True, False])
|
||||||
|
|
||||||
|
def testEmpty(self):
|
||||||
|
predictions = np.empty([0, 5])
|
||||||
|
target = np.empty([0], np.int32)
|
||||||
|
self._validateInTopK(predictions, target, 2, [])
|
||||||
|
|
||||||
def testTensorK(self):
|
def testTensorK(self):
|
||||||
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
||||||
target = [0, 2]
|
target = [0, 2]
|
||||||
|
Loading…
Reference in New Issue
Block a user