Set the indices in TopK op for an edge case where k=1 and input is all NaN.

The returned indices aren't set in this case because input.maximum() returns
-Inf and NaN != -Inf. Even if input.maximum() returns NaN, the indices still
won't be set due to NaN != NaN. This behavior surprises users and is likely to
cause index-out-of-range error in downstream operations.

PiperOrigin-RevId: 257497013
This commit is contained in:
A. Unique TensorFlower 2019-07-10 15:42:27 -07:00 committed by TensorFlower Gardener
parent 943a1757d5
commit 1c94a5bed5
2 changed files with 6 additions and 0 deletions

View File

@ -123,12 +123,14 @@ struct TopKFunctor<CPUDevice, T> {
input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one); input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
// Get the indices of the maximum values. // Get the indices of the maximum values.
for (int r = 0; r < num_rows; ++r) { for (int r = 0; r < num_rows; ++r) {
indices(r, 0) = 0;
for (int c = 0; c < num_cols; ++c) { for (int c = 0; c < num_cols; ++c) {
if (values(r, 0) == input(r, c)) { if (values(r, 0) == input(r, c)) {
indices(r, 0) = c; indices(r, 0) = c;
break; break;
} }
} }
values(r, 0) = input(r, indices(r, 0));
} }
return Status::OK(); return Status::OK();

View File

@ -108,6 +108,10 @@ class TopKTest(test.TestCase):
values = -np.sort(-inputs)[:k] values = -np.sort(-inputs)[:k]
self._validateTopK(inputs, k, values, indices) self._validateTopK(inputs, k, values, indices)
def testTop1AllNan(self):
inputs = [[np.NaN, np.NaN], [np.NaN, np.NaN]]
self._validateTopK(inputs, 1, [[np.NaN], [np.NaN]], [[0], [0]])
def _testLargeSort(self, dtype): def _testLargeSort(self, dtype):
b = 10 b = 10
n = 5000 n = 5000