Set the indices in TopK op for an edge case where k=1 and input is all NaN.
The returned indices aren't set in this case because input.maximum() returns -Inf and NaN != -Inf. Even if input.maximum() returns NaN, the indices still won't be set due to NaN != NaN. This behavior surprises users and is likely to cause index-out-of-range error in downstream operations. PiperOrigin-RevId: 257497013
This commit is contained in:
parent
943a1757d5
commit
1c94a5bed5
@ -123,12 +123,14 @@ struct TopKFunctor<CPUDevice, T> {
|
|||||||
input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
|
input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
|
||||||
// Get the indices of the maximum values.
|
// Get the indices of the maximum values.
|
||||||
for (int r = 0; r < num_rows; ++r) {
|
for (int r = 0; r < num_rows; ++r) {
|
||||||
|
indices(r, 0) = 0;
|
||||||
for (int c = 0; c < num_cols; ++c) {
|
for (int c = 0; c < num_cols; ++c) {
|
||||||
if (values(r, 0) == input(r, c)) {
|
if (values(r, 0) == input(r, c)) {
|
||||||
indices(r, 0) = c;
|
indices(r, 0) = c;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
values(r, 0) = input(r, indices(r, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -108,6 +108,10 @@ class TopKTest(test.TestCase):
|
|||||||
values = -np.sort(-inputs)[:k]
|
values = -np.sort(-inputs)[:k]
|
||||||
self._validateTopK(inputs, k, values, indices)
|
self._validateTopK(inputs, k, values, indices)
|
||||||
|
|
||||||
|
def testTop1AllNan(self):
|
||||||
|
inputs = [[np.NaN, np.NaN], [np.NaN, np.NaN]]
|
||||||
|
self._validateTopK(inputs, 1, [[np.NaN], [np.NaN]], [[0], [0]])
|
||||||
|
|
||||||
def _testLargeSort(self, dtype):
|
def _testLargeSort(self, dtype):
|
||||||
b = 10
|
b = 10
|
||||||
n = 5000
|
n = 5000
|
||||||
|
Loading…
Reference in New Issue
Block a user