Unify NthElement errors between graph & eager
PiperOrigin-RevId: 324157762 Change-Id: I7953bbed6ae37b0ac4999acf8e3c15a87974433b
This commit is contained in:
parent
3f576ef723
commit
8472254a23
@ -40,23 +40,23 @@ class NthElementOp : public OpKernel {
|
|||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
// The second args is N, which must be a positive scalar.
|
// The second args is N, which must be a positive scalar.
|
||||||
const auto& n_in = context->input(1);
|
const auto& n_in = context->input(1);
|
||||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_in.shape()),
|
OP_REQUIRES(
|
||||||
errors::InvalidArgument("N must be scalar, got shape ",
|
context, TensorShapeUtils::IsScalar(n_in.shape()),
|
||||||
n_in.shape().DebugString()));
|
errors::InvalidArgument("N must be scalar but has rank ", n_in.dims()));
|
||||||
int n = n_in.scalar<int32>()();
|
int n = n_in.scalar<int32>()();
|
||||||
OP_REQUIRES(context, n >= 0,
|
OP_REQUIRES(context, n >= 0,
|
||||||
errors::InvalidArgument("Need n >= 0, got ", n));
|
errors::InvalidArgument("n must be non-negative but is ", n));
|
||||||
|
|
||||||
// The first args is input tensor, which must have 1 dimension at least.
|
// The first args is input tensor, which must have 1 dimension at least.
|
||||||
const Tensor& input_in = context->input(0);
|
const Tensor& input_in = context->input(0);
|
||||||
const int num_dims = input_in.dims();
|
const int num_dims = input_in.dims();
|
||||||
OP_REQUIRES(context, num_dims >= 1,
|
OP_REQUIRES(context, num_dims >= 1,
|
||||||
errors::InvalidArgument("Input must be >= 1-D, got shape ",
|
errors::InvalidArgument(
|
||||||
input_in.shape().DebugString()));
|
"Input must be at least rank 1 but is rank ", num_dims));
|
||||||
// The last dimension of input tensor must be greater than N.
|
// The last dimension of input tensor must be greater than N.
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, input_in.dim_size(num_dims - 1) > n,
|
context, input_in.dim_size(num_dims - 1) > n,
|
||||||
errors::InvalidArgument("Input must have at least n+1 columns"));
|
errors::InvalidArgument("Input must have last dimension > n = ", n));
|
||||||
|
|
||||||
// std::nth_element only support the nth-smallest selection.
|
// std::nth_element only support the nth-smallest selection.
|
||||||
if (reverse_) {
|
if (reverse_) {
|
||||||
|
@ -19,13 +19,15 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
from tensorflow.python.eager import backprop
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
from tensorflow.python.ops import nn_ops
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
|
from tensorflow.python.ops import nn_ops
|
||||||
|
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -112,65 +114,77 @@ class NthElementTest(test.TestCase):
|
|||||||
self._testEnumerateN([10, 10, 10])
|
self._testEnumerateN([10, 10, 10])
|
||||||
self._testEnumerateN([10, 10, 10, 10])
|
self._testEnumerateN([10, 10, 10, 10])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testInvalidInput(self):
|
def testInvalidInput(self):
|
||||||
with self.assertRaisesRegex(ValueError, "at least rank 1 but is rank 0"):
|
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||||
|
"at least rank 1 but is rank 0"):
|
||||||
nn_ops.nth_element(5, 0)
|
nn_ops.nth_element(5, 0)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
# Test with placeholders
|
||||||
def testInvalidInputAtEval(self):
|
with ops.Graph().as_default():
|
||||||
with self.session(use_gpu=False):
|
with self.session(use_gpu=False):
|
||||||
v = array_ops.placeholder(dtype=dtypes.float32)
|
v = array_ops.placeholder(dtype=dtypes.int32)
|
||||||
with self.assertRaisesOpError("Input must be >= 1-D"):
|
with self.assertRaisesOpError("at least rank 1 but is rank 0"):
|
||||||
nn_ops.nth_element(v, 0).eval(feed_dict={v: 5.0})
|
nn_ops.nth_element(v, 0).eval(feed_dict={v: 5})
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testInvalidN(self):
|
def testInvalidN(self):
|
||||||
with self.assertRaisesRegex(ValueError, "non-negative but is -1"):
|
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||||
|
"non-negative but is -1"):
|
||||||
nn_ops.nth_element([5], -1)
|
nn_ops.nth_element([5], -1)
|
||||||
with self.assertRaisesRegex(ValueError, "scalar but has rank 1"):
|
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||||
|
"scalar but has rank 1"):
|
||||||
nn_ops.nth_element([5, 6, 3], [1])
|
nn_ops.nth_element([5, 6, 3], [1])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
# Test with placeholders
|
||||||
def testInvalidNAtEval(self):
|
with ops.Graph().as_default():
|
||||||
inputs = [[0.1, 0.2], [0.3, 0.4]]
|
with self.session(use_gpu=False):
|
||||||
with self.session(use_gpu=False):
|
n = array_ops.placeholder(dtypes.int32)
|
||||||
n = array_ops.placeholder(dtypes.int32)
|
values = nn_ops.nth_element([5], n)
|
||||||
values = nn_ops.nth_element(inputs, n)
|
with self.assertRaisesOpError("non-negative but is -1"):
|
||||||
with self.assertRaisesOpError("Need n >= 0, got -7"):
|
values.eval(feed_dict={n: -1})
|
||||||
values.eval(feed_dict={n: -7})
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testNTooLarge(self):
|
def testNTooLarge(self):
|
||||||
inputs = [[0.1, 0.2], [0.3, 0.4]]
|
inputs = [[0.1, 0.2], [0.3, 0.4]]
|
||||||
with self.assertRaisesRegex(ValueError, "must have last dimension > n = 2"):
|
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||||
|
"must have last dimension > n = 2"):
|
||||||
nn_ops.nth_element(inputs, 2)
|
nn_ops.nth_element(inputs, 2)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
# Test with placeholders
|
||||||
def testNTooLargeAtEval(self):
|
with ops.Graph().as_default():
|
||||||
inputs = [[0.1, 0.2], [0.3, 0.4]]
|
with self.session(use_gpu=False):
|
||||||
with self.session(use_gpu=False):
|
n = array_ops.placeholder(dtypes.int32)
|
||||||
n = array_ops.placeholder(dtypes.int32)
|
values = nn_ops.nth_element(inputs, n)
|
||||||
values = nn_ops.nth_element(inputs, n)
|
with self.assertRaisesOpError("must have last dimension > n = 2"):
|
||||||
with self.assertRaisesOpError(r"Input must have at least n\+1 columns"):
|
values.eval(feed_dict={n: 2})
|
||||||
values.eval(feed_dict={n: 2})
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGradients(self):
|
def testGradients(self):
|
||||||
with self.session(use_gpu=False) as sess:
|
x = [
|
||||||
inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
|
[2., -1., 1000., 3., 1000.],
|
||||||
values = nn_ops.nth_element(inputs, 3)
|
[1., 5., 2., 4., 3.],
|
||||||
grad = sess.run(
|
[2., 2., 2., 2., 2.],
|
||||||
gradients_impl.gradients(
|
]
|
||||||
values, inputs, grad_ys=[[-1., 2., 5.]]),
|
grad_ys = [[-1., 2., 5.]]
|
||||||
feed_dict={inputs: [[2., -1., 1000., 3., 1000.],
|
result = [
|
||||||
[1., 5., 2., 4., 3.],
|
[0, 0, -0.5, 0, -0.5],
|
||||||
[2., 2., 2., 2., 2.],
|
[0, 0, 0, 2, 0],
|
||||||
]})
|
[1, 1, 1, 1, 1],
|
||||||
self.assertAllClose(grad[0], [[0, 0, -0.5, 0, -0.5],
|
]
|
||||||
[0, 0, 0, 2, 0],
|
if context.executing_eagerly():
|
||||||
[1, 1, 1, 1, 1],
|
inputs = ops.convert_to_tensor(x)
|
||||||
])
|
with backprop.GradientTape() as tape:
|
||||||
|
tape.watch(inputs)
|
||||||
|
values = nn_ops.nth_element(inputs, 3)
|
||||||
|
grad = tape.gradient(values, inputs, ops.convert_to_tensor(grad_ys))
|
||||||
|
self.assertAllClose(grad[0], result)
|
||||||
|
|
||||||
|
# Test with tf.gradients
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with self.session(use_gpu=False) as sess:
|
||||||
|
inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
|
||||||
|
values = nn_ops.nth_element(inputs, 3)
|
||||||
|
grad = sess.run(
|
||||||
|
gradients_impl.gradients(values, inputs, grad_ys=grad_ys),
|
||||||
|
feed_dict={inputs: x})
|
||||||
|
self.assertAllClose(grad[0], result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user