Set the indices in TopK op for an edge case where k=1 and input is all NaN.
The returned indices aren't set in this case because input.maximum() returns -Inf and NaN != -Inf. Even if input.maximum() returns NaN, the indices still won't be set due to NaN != NaN. This behavior surprises users and is likely to cause index-out-of-range error in downstream operations. PiperOrigin-RevId: 257497013
This commit is contained in:
parent
943a1757d5
commit
1c94a5bed5
@ -123,12 +123,14 @@ struct TopKFunctor<CPUDevice, T> {
|
||||
input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
|
||||
// Get the indices of the maximum values.
|
||||
for (int r = 0; r < num_rows; ++r) {
|
||||
indices(r, 0) = 0;
|
||||
for (int c = 0; c < num_cols; ++c) {
|
||||
if (values(r, 0) == input(r, c)) {
|
||||
indices(r, 0) = c;
|
||||
break;
|
||||
}
|
||||
}
|
||||
values(r, 0) = input(r, indices(r, 0));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -108,6 +108,10 @@ class TopKTest(test.TestCase):
|
||||
values = -np.sort(-inputs)[:k]
|
||||
self._validateTopK(inputs, k, values, indices)
|
||||
|
||||
def testTop1AllNan(self):
|
||||
inputs = [[np.NaN, np.NaN], [np.NaN, np.NaN]]
|
||||
self._validateTopK(inputs, 1, [[np.NaN], [np.NaN]], [[0], [0]])
|
||||
|
||||
def _testLargeSort(self, dtype):
|
||||
b = 10
|
||||
n = 5000
|
||||
|
Loading…
Reference in New Issue
Block a user