Unify NthElement errors between graph & eager

PiperOrigin-RevId: 324157762
Change-Id: I7953bbed6ae37b0ac4999acf8e3c15a87974433b
This commit is contained in:
Gaurav Jain 2020-07-30 23:14:16 -07:00 committed by TensorFlower Gardener
parent 3f576ef723
commit 8472254a23
2 changed files with 68 additions and 54 deletions
tensorflow
core/kernels
python/kernel_tests

View File

@ -40,23 +40,23 @@ class NthElementOp : public OpKernel {
void Compute(OpKernelContext* context) override {
// The second args is N, which must be a positive scalar.
const auto& n_in = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_in.shape()),
errors::InvalidArgument("N must be scalar, got shape ",
n_in.shape().DebugString()));
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(n_in.shape()),
errors::InvalidArgument("N must be scalar but has rank ", n_in.dims()));
int n = n_in.scalar<int32>()();
OP_REQUIRES(context, n >= 0,
errors::InvalidArgument("Need n >= 0, got ", n));
errors::InvalidArgument("n must be non-negative but is ", n));
// The first args is input tensor, which must have 1 dimension at least.
const Tensor& input_in = context->input(0);
const int num_dims = input_in.dims();
OP_REQUIRES(context, num_dims >= 1,
errors::InvalidArgument("Input must be >= 1-D, got shape ",
input_in.shape().DebugString()));
errors::InvalidArgument(
"Input must be at least rank 1 but is rank ", num_dims));
// The last dimension of input tensor must be greater than N.
OP_REQUIRES(
context, input_in.dim_size(num_dims - 1) > n,
errors::InvalidArgument("Input must have at least n+1 columns"));
errors::InvalidArgument("Input must have last dimension > n = ", n));
// std::nth_element only support the nth-smallest selection.
if (reverse_) {

View File

@ -19,13 +19,15 @@ from __future__ import print_function
import numpy as np
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@ -112,65 +114,77 @@ class NthElementTest(test.TestCase):
self._testEnumerateN([10, 10, 10])
self._testEnumerateN([10, 10, 10, 10])
@test_util.run_deprecated_v1
def testInvalidInput(self):
with self.assertRaisesRegex(ValueError, "at least rank 1 but is rank 0"):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"at least rank 1 but is rank 0"):
nn_ops.nth_element(5, 0)
@test_util.run_deprecated_v1
def testInvalidInputAtEval(self):
with self.session(use_gpu=False):
v = array_ops.placeholder(dtype=dtypes.float32)
with self.assertRaisesOpError("Input must be >= 1-D"):
nn_ops.nth_element(v, 0).eval(feed_dict={v: 5.0})
# Test with placeholders
with ops.Graph().as_default():
with self.session(use_gpu=False):
v = array_ops.placeholder(dtype=dtypes.int32)
with self.assertRaisesOpError("at least rank 1 but is rank 0"):
nn_ops.nth_element(v, 0).eval(feed_dict={v: 5})
@test_util.run_deprecated_v1
def testInvalidN(self):
with self.assertRaisesRegex(ValueError, "non-negative but is -1"):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"non-negative but is -1"):
nn_ops.nth_element([5], -1)
with self.assertRaisesRegex(ValueError, "scalar but has rank 1"):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"scalar but has rank 1"):
nn_ops.nth_element([5, 6, 3], [1])
@test_util.run_deprecated_v1
def testInvalidNAtEval(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
with self.session(use_gpu=False):
n = array_ops.placeholder(dtypes.int32)
values = nn_ops.nth_element(inputs, n)
with self.assertRaisesOpError("Need n >= 0, got -7"):
values.eval(feed_dict={n: -7})
# Test with placeholders
with ops.Graph().as_default():
with self.session(use_gpu=False):
n = array_ops.placeholder(dtypes.int32)
values = nn_ops.nth_element([5], n)
with self.assertRaisesOpError("non-negative but is -1"):
values.eval(feed_dict={n: -1})
@test_util.run_deprecated_v1
def testNTooLarge(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
with self.assertRaisesRegex(ValueError, "must have last dimension > n = 2"):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"must have last dimension > n = 2"):
nn_ops.nth_element(inputs, 2)
@test_util.run_deprecated_v1
def testNTooLargeAtEval(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
with self.session(use_gpu=False):
n = array_ops.placeholder(dtypes.int32)
values = nn_ops.nth_element(inputs, n)
with self.assertRaisesOpError(r"Input must have at least n\+1 columns"):
values.eval(feed_dict={n: 2})
# Test with placeholders
with ops.Graph().as_default():
with self.session(use_gpu=False):
n = array_ops.placeholder(dtypes.int32)
values = nn_ops.nth_element(inputs, n)
with self.assertRaisesOpError("must have last dimension > n = 2"):
values.eval(feed_dict={n: 2})
@test_util.run_deprecated_v1
def testGradients(self):
with self.session(use_gpu=False) as sess:
inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
values = nn_ops.nth_element(inputs, 3)
grad = sess.run(
gradients_impl.gradients(
values, inputs, grad_ys=[[-1., 2., 5.]]),
feed_dict={inputs: [[2., -1., 1000., 3., 1000.],
[1., 5., 2., 4., 3.],
[2., 2., 2., 2., 2.],
]})
self.assertAllClose(grad[0], [[0, 0, -0.5, 0, -0.5],
[0, 0, 0, 2, 0],
[1, 1, 1, 1, 1],
])
x = [
[2., -1., 1000., 3., 1000.],
[1., 5., 2., 4., 3.],
[2., 2., 2., 2., 2.],
]
grad_ys = [[-1., 2., 5.]]
result = [
[0, 0, -0.5, 0, -0.5],
[0, 0, 0, 2, 0],
[1, 1, 1, 1, 1],
]
if context.executing_eagerly():
inputs = ops.convert_to_tensor(x)
with backprop.GradientTape() as tape:
tape.watch(inputs)
values = nn_ops.nth_element(inputs, 3)
grad = tape.gradient(values, inputs, ops.convert_to_tensor(grad_ys))
self.assertAllClose(grad[0], result)
# Test with tf.gradients
with ops.Graph().as_default():
with self.session(use_gpu=False) as sess:
inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
values = nn_ops.nth_element(inputs, 3)
grad = sess.run(
gradients_impl.gradients(values, inputs, grad_ys=grad_ys),
feed_dict={inputs: x})
self.assertAllClose(grad[0], result)