diff --git a/tensorflow/core/kernels/nth_element_op.cc b/tensorflow/core/kernels/nth_element_op.cc index 0e43cc19aae..dced32ef7df 100644 --- a/tensorflow/core/kernels/nth_element_op.cc +++ b/tensorflow/core/kernels/nth_element_op.cc @@ -40,23 +40,23 @@ class NthElementOp : public OpKernel { void Compute(OpKernelContext* context) override { // The second args is N, which must be a positive scalar. const auto& n_in = context->input(1); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_in.shape()), - errors::InvalidArgument("N must be scalar, got shape ", - n_in.shape().DebugString())); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(n_in.shape()), + errors::InvalidArgument("N must be scalar but has rank ", n_in.dims())); int n = n_in.scalar()(); OP_REQUIRES(context, n >= 0, - errors::InvalidArgument("Need n >= 0, got ", n)); + errors::InvalidArgument("n must be non-negative but is ", n)); // The first args is input tensor, which must have 1 dimension at least. const Tensor& input_in = context->input(0); const int num_dims = input_in.dims(); OP_REQUIRES(context, num_dims >= 1, - errors::InvalidArgument("Input must be >= 1-D, got shape ", - input_in.shape().DebugString())); + errors::InvalidArgument( + "Input must be at least rank 1 but is rank ", num_dims)); // The last dimension of input tensor must be greater than N. OP_REQUIRES( context, input_in.dim_size(num_dims - 1) > n, - errors::InvalidArgument("Input must have at least n+1 columns")); + errors::InvalidArgument("Input must have last dimension > n = ", n)); // std::nth_element only support the nth-smallest selection. if (reverse_) { diff --git a/tensorflow/python/kernel_tests/nth_element_op_test.py b/tensorflow/python/kernel_tests/nth_element_op_test.py index d8b9adb8731..488dfc0db66 100644 --- a/tensorflow/python/kernel_tests/nth_element_op_test.py +++ b/tensorflow/python/kernel_tests/nth_element_op_test.py @@ -19,13 +19,15 @@ from __future__ import print_function import numpy as np -import tensorflow.python.ops.nn_grad # pylint: disable=unused-import +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import nn_ops +import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -112,65 +114,77 @@ class NthElementTest(test.TestCase): self._testEnumerateN([10, 10, 10]) self._testEnumerateN([10, 10, 10, 10]) - @test_util.run_deprecated_v1 def testInvalidInput(self): - with self.assertRaisesRegex(ValueError, "at least rank 1 but is rank 0"): + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "at least rank 1 but is rank 0"): nn_ops.nth_element(5, 0) - @test_util.run_deprecated_v1 - def testInvalidInputAtEval(self): - with self.session(use_gpu=False): - v = array_ops.placeholder(dtype=dtypes.float32) - with self.assertRaisesOpError("Input must be >= 1-D"): - nn_ops.nth_element(v, 0).eval(feed_dict={v: 5.0}) + # Test with placeholders + with ops.Graph().as_default(): + with self.session(use_gpu=False): + v = array_ops.placeholder(dtype=dtypes.int32) + with self.assertRaisesOpError("at least rank 1 but is rank 0"): + nn_ops.nth_element(v, 0).eval(feed_dict={v: 5}) - @test_util.run_deprecated_v1 def testInvalidN(self): - with self.assertRaisesRegex(ValueError, "non-negative but is -1"): + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "non-negative but is -1"): nn_ops.nth_element([5], -1) - with self.assertRaisesRegex(ValueError, "scalar but has rank 1"): + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "scalar but has rank 1"): nn_ops.nth_element([5, 6, 3], [1]) - @test_util.run_deprecated_v1 - def testInvalidNAtEval(self): - inputs = [[0.1, 0.2], [0.3, 0.4]] - with self.session(use_gpu=False): - n = array_ops.placeholder(dtypes.int32) - values = nn_ops.nth_element(inputs, n) - with self.assertRaisesOpError("Need n >= 0, got -7"): - values.eval(feed_dict={n: -7}) + # Test with placeholders + with ops.Graph().as_default(): + with self.session(use_gpu=False): + n = array_ops.placeholder(dtypes.int32) + values = nn_ops.nth_element([5], n) + with self.assertRaisesOpError("non-negative but is -1"): + values.eval(feed_dict={n: -1}) - @test_util.run_deprecated_v1 def testNTooLarge(self): inputs = [[0.1, 0.2], [0.3, 0.4]] - with self.assertRaisesRegex(ValueError, "must have last dimension > n = 2"): + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "must have last dimension > n = 2"): nn_ops.nth_element(inputs, 2) - @test_util.run_deprecated_v1 - def testNTooLargeAtEval(self): - inputs = [[0.1, 0.2], [0.3, 0.4]] - with self.session(use_gpu=False): - n = array_ops.placeholder(dtypes.int32) - values = nn_ops.nth_element(inputs, n) - with self.assertRaisesOpError(r"Input must have at least n\+1 columns"): - values.eval(feed_dict={n: 2}) + # Test with placeholders + with ops.Graph().as_default(): + with self.session(use_gpu=False): + n = array_ops.placeholder(dtypes.int32) + values = nn_ops.nth_element(inputs, n) + with self.assertRaisesOpError("must have last dimension > n = 2"): + values.eval(feed_dict={n: 2}) - @test_util.run_deprecated_v1 def testGradients(self): - with self.session(use_gpu=False) as sess: - inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5]) - values = nn_ops.nth_element(inputs, 3) - grad = sess.run( - gradients_impl.gradients( - values, inputs, grad_ys=[[-1., 2., 5.]]), - feed_dict={inputs: [[2., -1., 1000., 3., 1000.], - [1., 5., 2., 4., 3.], - [2., 2., 2., 2., 2.], - ]}) - self.assertAllClose(grad[0], [[0, 0, -0.5, 0, -0.5], - [0, 0, 0, 2, 0], - [1, 1, 1, 1, 1], - ]) + x = [ + [2., -1., 1000., 3., 1000.], + [1., 5., 2., 4., 3.], + [2., 2., 2., 2., 2.], + ] + grad_ys = [[-1., 2., 5.]] + result = [ + [0, 0, -0.5, 0, -0.5], + [0, 0, 0, 2, 0], + [1, 1, 1, 1, 1], + ] + if context.executing_eagerly(): + inputs = ops.convert_to_tensor(x) + with backprop.GradientTape() as tape: + tape.watch(inputs) + values = nn_ops.nth_element(inputs, 3) + grad = tape.gradient(values, inputs, ops.convert_to_tensor(grad_ys)) + self.assertAllClose(grad[0], result) + + # Test with tf.gradients + with ops.Graph().as_default(): + with self.session(use_gpu=False) as sess: + inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5]) + values = nn_ops.nth_element(inputs, 3) + grad = sess.run( + gradients_impl.gradients(values, inputs, grad_ys=grad_ys), + feed_dict={inputs: x}) + self.assertAllClose(grad[0], result)