Unify NthElement errors between graph & eager

PiperOrigin-RevId: 324157762
Change-Id: I7953bbed6ae37b0ac4999acf8e3c15a87974433b
Authored by Gaurav Jain on 2020-07-30 23:14:16 -07:00; committed by TensorFlower Gardener
parent 3f576ef723
commit 8472254a23
2 changed files with 68 additions and 54 deletions
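
Roughly what the change does, as a minimal illustrative sketch (not part of the commit, assuming a TF 2.x build where eager execution is the default): the kernel's runtime error strings are reworded to match the op's shape-inference errors, so a single regex in the tests can match both the InvalidArgumentError raised under eager execution and the ValueError raised at graph-construction time.

from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn_ops

# Eager execution: the C++ kernel rejects the negative n at run time.
try:
  nn_ops.nth_element([5.0, 6.0, 3.0], -1)
except errors.InvalidArgumentError as e:
  print(e)  # message contains "n must be non-negative but is -1"

# Graph construction: shape inference rejects the same call before anything runs.
with ops.Graph().as_default():
  try:
    nn_ops.nth_element([5.0, 6.0, 3.0], -1)
  except ValueError as e:
    print(e)  # message also contains "non-negative but is -1"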


@@ -40,23 +40,23 @@ class NthElementOp : public OpKernel {
   void Compute(OpKernelContext* context) override {
     // The second args is N, which must be a positive scalar.
     const auto& n_in = context->input(1);
-    OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_in.shape()),
-                errors::InvalidArgument("N must be scalar, got shape ",
-                                        n_in.shape().DebugString()));
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsScalar(n_in.shape()),
+        errors::InvalidArgument("N must be scalar but has rank ", n_in.dims()));
     int n = n_in.scalar<int32>()();
     OP_REQUIRES(context, n >= 0,
-                errors::InvalidArgument("Need n >= 0, got ", n));
+                errors::InvalidArgument("n must be non-negative but is ", n));
     // The first args is input tensor, which must have 1 dimension at least.
     const Tensor& input_in = context->input(0);
     const int num_dims = input_in.dims();
     OP_REQUIRES(context, num_dims >= 1,
-                errors::InvalidArgument("Input must be >= 1-D, got shape ",
-                                        input_in.shape().DebugString()));
+                errors::InvalidArgument(
+                    "Input must be at least rank 1 but is rank ", num_dims));
     // The last dimension of input tensor must be greater than N.
     OP_REQUIRES(
         context, input_in.dim_size(num_dims - 1) > n,
-        errors::InvalidArgument("Input must have at least n+1 columns"));
+        errors::InvalidArgument("Input must have last dimension > n = ", n));
     // std::nth_element only support the nth-smallest selection.
     if (reverse_) {


@@ -19,13 +19,15 @@ from __future__ import print_function
 import numpy as np
 
-import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import nn_ops
+import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -112,65 +114,77 @@ class NthElementTest(test.TestCase):
     self._testEnumerateN([10, 10, 10])
     self._testEnumerateN([10, 10, 10, 10])
 
-  @test_util.run_deprecated_v1
   def testInvalidInput(self):
-    with self.assertRaisesRegex(ValueError, "at least rank 1 but is rank 0"):
+    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
+                                "at least rank 1 but is rank 0"):
       nn_ops.nth_element(5, 0)
 
-  @test_util.run_deprecated_v1
-  def testInvalidInputAtEval(self):
-    with self.session(use_gpu=False):
-      v = array_ops.placeholder(dtype=dtypes.float32)
-      with self.assertRaisesOpError("Input must be >= 1-D"):
-        nn_ops.nth_element(v, 0).eval(feed_dict={v: 5.0})
+    # Test with placeholders
+    with ops.Graph().as_default():
+      with self.session(use_gpu=False):
+        v = array_ops.placeholder(dtype=dtypes.int32)
+        with self.assertRaisesOpError("at least rank 1 but is rank 0"):
+          nn_ops.nth_element(v, 0).eval(feed_dict={v: 5})
 
-  @test_util.run_deprecated_v1
   def testInvalidN(self):
-    with self.assertRaisesRegex(ValueError, "non-negative but is -1"):
+    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
+                                "non-negative but is -1"):
       nn_ops.nth_element([5], -1)
-    with self.assertRaisesRegex(ValueError, "scalar but has rank 1"):
+    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
+                                "scalar but has rank 1"):
       nn_ops.nth_element([5, 6, 3], [1])
 
-  @test_util.run_deprecated_v1
-  def testInvalidNAtEval(self):
-    inputs = [[0.1, 0.2], [0.3, 0.4]]
-    with self.session(use_gpu=False):
-      n = array_ops.placeholder(dtypes.int32)
-      values = nn_ops.nth_element(inputs, n)
-      with self.assertRaisesOpError("Need n >= 0, got -7"):
-        values.eval(feed_dict={n: -7})
+    # Test with placeholders
+    with ops.Graph().as_default():
+      with self.session(use_gpu=False):
+        n = array_ops.placeholder(dtypes.int32)
+        values = nn_ops.nth_element([5], n)
+        with self.assertRaisesOpError("non-negative but is -1"):
+          values.eval(feed_dict={n: -1})
 
-  @test_util.run_deprecated_v1
   def testNTooLarge(self):
     inputs = [[0.1, 0.2], [0.3, 0.4]]
-    with self.assertRaisesRegex(ValueError, "must have last dimension > n = 2"):
+    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
+                                "must have last dimension > n = 2"):
      nn_ops.nth_element(inputs, 2)
 
-  @test_util.run_deprecated_v1
-  def testNTooLargeAtEval(self):
-    inputs = [[0.1, 0.2], [0.3, 0.4]]
-    with self.session(use_gpu=False):
-      n = array_ops.placeholder(dtypes.int32)
-      values = nn_ops.nth_element(inputs, n)
-      with self.assertRaisesOpError(r"Input must have at least n\+1 columns"):
-        values.eval(feed_dict={n: 2})
+    # Test with placeholders
+    with ops.Graph().as_default():
+      with self.session(use_gpu=False):
+        n = array_ops.placeholder(dtypes.int32)
+        values = nn_ops.nth_element(inputs, n)
+        with self.assertRaisesOpError("must have last dimension > n = 2"):
+          values.eval(feed_dict={n: 2})
 
-  @test_util.run_deprecated_v1
   def testGradients(self):
-    with self.session(use_gpu=False) as sess:
-      inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
-      values = nn_ops.nth_element(inputs, 3)
-      grad = sess.run(
-          gradients_impl.gradients(
-              values, inputs, grad_ys=[[-1., 2., 5.]]),
-          feed_dict={inputs: [[2., -1., 1000., 3., 1000.],
-                              [1., 5., 2., 4., 3.],
-                              [2., 2., 2., 2., 2.],
-                             ]})
-      self.assertAllClose(grad[0], [[0, 0, -0.5, 0, -0.5],
-                                    [0, 0, 0, 2, 0],
-                                    [1, 1, 1, 1, 1],
-                                   ])
+    x = [
+        [2., -1., 1000., 3., 1000.],
+        [1., 5., 2., 4., 3.],
+        [2., 2., 2., 2., 2.],
+    ]
+    grad_ys = [[-1., 2., 5.]]
+    result = [
+        [0, 0, -0.5, 0, -0.5],
+        [0, 0, 0, 2, 0],
+        [1, 1, 1, 1, 1],
+    ]
+    if context.executing_eagerly():
+      inputs = ops.convert_to_tensor(x)
+      with backprop.GradientTape() as tape:
+        tape.watch(inputs)
+        values = nn_ops.nth_element(inputs, 3)
+      grad = tape.gradient(values, inputs, ops.convert_to_tensor(grad_ys))
+      self.assertAllClose(grad[0], result)
+
+    # Test with tf.gradients
+    with ops.Graph().as_default():
+      with self.session(use_gpu=False) as sess:
+        inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
+        values = nn_ops.nth_element(inputs, 3)
+        grad = sess.run(
+            gradients_impl.gradients(values, inputs, grad_ys=grad_ys),
+            feed_dict={inputs: x})
+        self.assertAllClose(grad[0], result)