STT-tensorflow/tensorflow/python/kernel_tests/topk_op_test.py
A. Unique TensorFlower 1c94a5bed5 Set the indices in TopK op for an edge case where k=1 and input is all NaN.
The returned indices aren't set in this case because input.maximum() returns
-Inf and NaN != -Inf. Even if input.maximum() returns NaN, the indices still
won't be set due to NaN != NaN. This behavior surprises users and is likely to
cause index-out-of-range error in downstream operations.

PiperOrigin-RevId: 257497013
2019-07-10 16:03:15 -07:00

254 lines
9.3 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TopK op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import sys
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
class TopKTest(test.TestCase):
def _validateTopK(self,
inputs,
k,
expected_values,
expected_indices,
sorted=True): # pylint: disable=redefined-builtin
np_expected_values = np.array(expected_values)
np_expected_indices = np.array(expected_indices)
with self.cached_session(use_gpu=True) as sess:
values_op, indices_op = nn_ops.top_k(inputs, k, sorted=sorted)
values, indices = self.evaluate([values_op, indices_op])
self.assertShapeEqual(np_expected_values, values_op)
self.assertShapeEqual(np_expected_indices, indices_op)
if sorted:
self.assertAllClose(np_expected_values, values)
# Do some special casing of equality of indices: if indices
# are not the same, but values are floating type, ensure that
# the values are within epsilon of each other.
if not np.issubdtype(np_expected_values.dtype, np.floating):
# Values are not floating point type; check indices exactly
self.assertAllEqual(np_expected_indices, indices)
else:
# Values are floating point; indices may be swapped for
# values near each other.
indices_not_equal = np_expected_indices != indices
if np.any(indices_not_equal):
values_unsure = values[indices_not_equal]
expected_values_unsure = expected_values[indices_not_equal]
self.assertAllClose(expected_values_unsure, values_unsure)
else:
np_inputs = np.array(inputs)
# Check that the indices are valid.
for result_index, src_index in np.ndenumerate(indices):
value = values[result_index]
expected_value = np_inputs[result_index[0], src_index]
np.testing.assert_almost_equal(value, expected_value)
# Check that if two elements are equal, the lower-index element appears
# first.
shape = values.shape
for batch_index in range(shape[0]):
for index in range(shape[1] - 1):
if np.isclose(values[batch_index, index],
values[batch_index, index + 1]):
self.assertLess(indices[batch_index, index],
indices[batch_index, index + 1])
# Now check the results, ignoring order.
self.assertAllEqual(np.sort(np_expected_indices), np.sort(indices))
self.assertAllClose(np.sort(np_expected_values), np.sort(values))
def testTop1(self):
inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
self._validateTopK(inputs, 1, [[0.4], [0.3]], [[3], [1]])
def testTop2(self):
inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.4, 0.2]]
self._validateTopK(inputs, 2, [[0.4, 0.3], [0.4, 0.3]], [[3, 1], [2, 1]])
def testTop3(self):
k = 5
inputs = np.random.permutation(np.linspace(0, 100, 6140, dtype=np.float64))
indices = np.argsort(-inputs)[:k]
values = -np.sort(-inputs)[:k]
self._validateTopK(inputs, k, values, indices)
def testTop1AllNan(self):
inputs = [[np.NaN, np.NaN], [np.NaN, np.NaN]]
self._validateTopK(inputs, 1, [[np.NaN], [np.NaN]], [[0], [0]])
def _testLargeSort(self, dtype):
b = 10
n = 5000
inputs = np.random.permutation(
np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
indices = np.argsort(-inputs, axis=1)
values = -np.sort(-inputs, axis=1)
self._validateTopK(inputs, n, values, indices)
def testLargeSort(self):
self._testLargeSort(np.float32)
self._testLargeSort(np.float16)
def _testLargeTopK(self, dtype):
b = 10
n = 5000
k = n - 1
inputs = np.random.permutation(
np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
indices = np.argsort(-inputs, axis=1)[:, :k]
values = -np.sort(-inputs, axis=1)[:, :k]
self._validateTopK(inputs, k, values, indices)
def testLargeTopK(self):
self._testLargeTopK(np.float32)
self._testLargeTopK(np.float16)
def _testMediumTopK(self, dtype):
b = 5
n = 500
k = 50
inputs = np.random.permutation(
np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
indices = np.argsort(-inputs, axis=1)[:, :k]
values = -np.sort(-inputs, axis=1)[:, :k]
self._validateTopK(inputs, k, values, indices)
def testMediumTopK(self):
self._testMediumTopK(np.float32)
self._testMediumTopK(np.float16)
def testStableSort(self):
b = 5
n = 500
for k in [1, 5, 50, 500]:
# Lots of repeated integers taking values in [0, 3]
inputs = np.random.permutation(
np.linspace(0, 3, b * n, dtype=np.int32)).reshape(b, n)
# Use mergesort, a stable sort, to get the indices.
indices = np.argsort(-inputs, axis=1, kind="mergesort")[:, :k]
values = -np.sort(-inputs, axis=1)[:, :k]
self._validateTopK(inputs, k, values, indices)
def testTopAll(self):
inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
self._validateTopK(inputs, 4, [[0.4, 0.3, 0.2, 0.1], [0.3, 0.3, 0.2, 0.1]],
[[3, 1, 2, 0], [1, 2, 3, 0]])
def testTop3Unsorted(self):
inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.4, 0.3, 0.2]]
self._validateTopK(
inputs,
3, [[0.2, 0.3, 0.4], [0.2, 0.4, 0.3]], [[2, 1, 3], [3, 1, 2]],
sorted=False)
def testTop3Vector(self):
inputs = [3, 6, 15, 18, 6, 12, 1, 17, 3, 0, 4, 19, 1, 6]
self._validateTopK(inputs, 3, [19, 18, 17], [11, 3, 7])
def testTensorK(self):
inputs = [3, 6, 15, 18, 6, 12, 1, 17, 3, 0, 4, 19, 1, 6]
k = constant_op.constant(3)
self._validateTopK(inputs, k, [19, 18, 17], [11, 3, 7])
def testTop3ZeroRows(self):
inputs = np.zeros([0, 10], dtype=np.float32)
self._validateTopK(inputs, 3, np.zeros([0, 3], dtype=np.float32),
np.zeros([0, 3], dtype=np.int32))
@test_util.run_deprecated_v1
def testKNegative(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
with self.session(use_gpu=True):
k = array_ops.placeholder(dtypes.int32)
values, _ = nn_ops.top_k(inputs, k)
with self.assertRaisesOpError("Need k >= 0, got -7"):
values.eval(feed_dict={k: -7})
@test_util.run_deprecated_v1
def testKTooLarge(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
with self.assertRaisesRegexp(ValueError,
r"must have last dimension >= k = 4"):
nn_ops.top_k(inputs, 4)
@test_util.run_deprecated_v1
def testTopKGradients(self):
with self.session(use_gpu=True) as sess:
inputs = array_ops.placeholder(dtypes.float32, shape=[2, 5])
values, _ = nn_ops.top_k(inputs, 3)
grad = sess.run(
gradients_impl.gradients(
values, inputs, grad_ys=[[[1., 2., 3.], [4., 5., 6.]]]),
feed_dict={inputs: [[2., -1., 1000., 3., 4.],
[1., 5., 2., 4., 3.]]})[0]
self.assertEqual(
grad.tolist(), [[0., 0., 1., 3., 2.], [0., 4., 0., 5., 6.]])
class TopKBenchmark(test.Benchmark):
def benchmarkTopK(self):
for (m, n, p, use_gpu) in itertools.product(
[128],
[10, 100, 1000, 10000, 100000],
[0.001, 0.01, 0.5, 0.99, 1.0],
[False, True]):
k = int(p * n)
if k == 0:
continue
name = "m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu)
device = "/%s:0" % ("gpu" if use_gpu else "cpu")
with ops.Graph().as_default():
with ops.device(device):
x = random_ops.random_uniform((m, n))
v = resource_variable_ops.ResourceVariable(x)
op = nn_ops.top_k(v, k)
with session.Session() as sess:
v.initializer.run()
r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
gb_processed_input = m * n / 1.0e9
throughput = gb_processed_input / r["wall_time"]
print("Benchmark: %s \t wall_time: %0.03g s \t "
"Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
sys.stdout.flush()
if __name__ == "__main__":
test.main()