Unify NthElement errors between graph & eager
PiperOrigin-RevId: 324157762 Change-Id: I7953bbed6ae37b0ac4999acf8e3c15a87974433b
This commit is contained in:
parent
3f576ef723
commit
8472254a23
tensorflow
@ -40,23 +40,23 @@ class NthElementOp : public OpKernel {
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// The second args is N, which must be a positive scalar.
|
||||
const auto& n_in = context->input(1);
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_in.shape()),
|
||||
errors::InvalidArgument("N must be scalar, got shape ",
|
||||
n_in.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
context, TensorShapeUtils::IsScalar(n_in.shape()),
|
||||
errors::InvalidArgument("N must be scalar but has rank ", n_in.dims()));
|
||||
int n = n_in.scalar<int32>()();
|
||||
OP_REQUIRES(context, n >= 0,
|
||||
errors::InvalidArgument("Need n >= 0, got ", n));
|
||||
errors::InvalidArgument("n must be non-negative but is ", n));
|
||||
|
||||
// The first args is input tensor, which must have 1 dimension at least.
|
||||
const Tensor& input_in = context->input(0);
|
||||
const int num_dims = input_in.dims();
|
||||
OP_REQUIRES(context, num_dims >= 1,
|
||||
errors::InvalidArgument("Input must be >= 1-D, got shape ",
|
||||
input_in.shape().DebugString()));
|
||||
errors::InvalidArgument(
|
||||
"Input must be at least rank 1 but is rank ", num_dims));
|
||||
// The last dimension of input tensor must be greater than N.
|
||||
OP_REQUIRES(
|
||||
context, input_in.dim_size(num_dims - 1) > n,
|
||||
errors::InvalidArgument("Input must have at least n+1 columns"));
|
||||
errors::InvalidArgument("Input must have last dimension > n = ", n));
|
||||
|
||||
// std::nth_element only support the nth-smallest selection.
|
||||
if (reverse_) {
|
||||
|
@ -19,13 +19,15 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import nn_ops
|
||||
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -112,65 +114,77 @@ class NthElementTest(test.TestCase):
|
||||
self._testEnumerateN([10, 10, 10])
|
||||
self._testEnumerateN([10, 10, 10, 10])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidInput(self):
|
||||
with self.assertRaisesRegex(ValueError, "at least rank 1 but is rank 0"):
|
||||
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||
"at least rank 1 but is rank 0"):
|
||||
nn_ops.nth_element(5, 0)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidInputAtEval(self):
|
||||
with self.session(use_gpu=False):
|
||||
v = array_ops.placeholder(dtype=dtypes.float32)
|
||||
with self.assertRaisesOpError("Input must be >= 1-D"):
|
||||
nn_ops.nth_element(v, 0).eval(feed_dict={v: 5.0})
|
||||
# Test with placeholders
|
||||
with ops.Graph().as_default():
|
||||
with self.session(use_gpu=False):
|
||||
v = array_ops.placeholder(dtype=dtypes.int32)
|
||||
with self.assertRaisesOpError("at least rank 1 but is rank 0"):
|
||||
nn_ops.nth_element(v, 0).eval(feed_dict={v: 5})
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidN(self):
|
||||
with self.assertRaisesRegex(ValueError, "non-negative but is -1"):
|
||||
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||
"non-negative but is -1"):
|
||||
nn_ops.nth_element([5], -1)
|
||||
with self.assertRaisesRegex(ValueError, "scalar but has rank 1"):
|
||||
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||
"scalar but has rank 1"):
|
||||
nn_ops.nth_element([5, 6, 3], [1])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidNAtEval(self):
|
||||
inputs = [[0.1, 0.2], [0.3, 0.4]]
|
||||
with self.session(use_gpu=False):
|
||||
n = array_ops.placeholder(dtypes.int32)
|
||||
values = nn_ops.nth_element(inputs, n)
|
||||
with self.assertRaisesOpError("Need n >= 0, got -7"):
|
||||
values.eval(feed_dict={n: -7})
|
||||
# Test with placeholders
|
||||
with ops.Graph().as_default():
|
||||
with self.session(use_gpu=False):
|
||||
n = array_ops.placeholder(dtypes.int32)
|
||||
values = nn_ops.nth_element([5], n)
|
||||
with self.assertRaisesOpError("non-negative but is -1"):
|
||||
values.eval(feed_dict={n: -1})
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNTooLarge(self):
|
||||
inputs = [[0.1, 0.2], [0.3, 0.4]]
|
||||
with self.assertRaisesRegex(ValueError, "must have last dimension > n = 2"):
|
||||
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||
"must have last dimension > n = 2"):
|
||||
nn_ops.nth_element(inputs, 2)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNTooLargeAtEval(self):
|
||||
inputs = [[0.1, 0.2], [0.3, 0.4]]
|
||||
with self.session(use_gpu=False):
|
||||
n = array_ops.placeholder(dtypes.int32)
|
||||
values = nn_ops.nth_element(inputs, n)
|
||||
with self.assertRaisesOpError(r"Input must have at least n\+1 columns"):
|
||||
values.eval(feed_dict={n: 2})
|
||||
# Test with placeholders
|
||||
with ops.Graph().as_default():
|
||||
with self.session(use_gpu=False):
|
||||
n = array_ops.placeholder(dtypes.int32)
|
||||
values = nn_ops.nth_element(inputs, n)
|
||||
with self.assertRaisesOpError("must have last dimension > n = 2"):
|
||||
values.eval(feed_dict={n: 2})
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradients(self):
|
||||
with self.session(use_gpu=False) as sess:
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
|
||||
values = nn_ops.nth_element(inputs, 3)
|
||||
grad = sess.run(
|
||||
gradients_impl.gradients(
|
||||
values, inputs, grad_ys=[[-1., 2., 5.]]),
|
||||
feed_dict={inputs: [[2., -1., 1000., 3., 1000.],
|
||||
[1., 5., 2., 4., 3.],
|
||||
[2., 2., 2., 2., 2.],
|
||||
]})
|
||||
self.assertAllClose(grad[0], [[0, 0, -0.5, 0, -0.5],
|
||||
[0, 0, 0, 2, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
])
|
||||
x = [
|
||||
[2., -1., 1000., 3., 1000.],
|
||||
[1., 5., 2., 4., 3.],
|
||||
[2., 2., 2., 2., 2.],
|
||||
]
|
||||
grad_ys = [[-1., 2., 5.]]
|
||||
result = [
|
||||
[0, 0, -0.5, 0, -0.5],
|
||||
[0, 0, 0, 2, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
]
|
||||
if context.executing_eagerly():
|
||||
inputs = ops.convert_to_tensor(x)
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(inputs)
|
||||
values = nn_ops.nth_element(inputs, 3)
|
||||
grad = tape.gradient(values, inputs, ops.convert_to_tensor(grad_ys))
|
||||
self.assertAllClose(grad[0], result)
|
||||
|
||||
# Test with tf.gradients
|
||||
with ops.Graph().as_default():
|
||||
with self.session(use_gpu=False) as sess:
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
|
||||
values = nn_ops.nth_element(inputs, 3)
|
||||
grad = sess.run(
|
||||
gradients_impl.gradients(values, inputs, grad_ys=grad_ys),
|
||||
feed_dict={inputs: x})
|
||||
self.assertAllClose(grad[0], result)
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user