Fix argument check tests to work in eager mode

PiperOrigin-RevId: 312170271
Change-Id: Ie7ffb52cf63559255b5463d651eb72b924a3c3bf
This commit is contained in:
Gaurav Jain 2020-05-18 15:51:05 -07:00 committed by TensorFlower Gardener
parent 94108993a3
commit 1acf6989bf
3 changed files with 67 additions and 68 deletions

View File

@ -43,9 +43,9 @@ typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename Tlen> template <typename Device, typename Tlen>
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
const Tensor& input = context->input(0); const Tensor& input = context->input(0);
const Tensor& seq_lens = context->input(1); const Tensor& seq_lengths = context->input(1);
auto seq_lens_t = seq_lens.vec<Tlen>(); auto seq_lens_t = seq_lengths.vec<Tlen>();
std::vector<Tlen> seq_lens_vec(seq_lens_t.size()); std::vector<Tlen> seq_lens_vec(seq_lens_t.size());
@ -56,15 +56,16 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, batch_dim != seq_dim, OP_REQUIRES(context, batch_dim != seq_dim,
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
OP_REQUIRES(context, seq_dim < input.dims(), OP_REQUIRES(context, seq_dim < input.dims(),
errors::InvalidArgument("seq_dim must be < input.dims()", "( ", errors::InvalidArgument("seq_dim must be < input rank", " ( ",
seq_dim, " vs. ", input.dims(), ")")); seq_dim, " vs. ", input.dims(), ")"));
OP_REQUIRES(context, batch_dim < input.dims(), OP_REQUIRES(context, batch_dim < input.dims(),
errors::InvalidArgument("batch_dim must be < input.dims()", "( ", errors::InvalidArgument("batch_dim must be < input rank", " ( ",
batch_dim, " vs. ", input.dims(), ")")); batch_dim, " vs. ", input.dims(), ")"));
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), OP_REQUIRES(
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, context, seq_lengths.NumElements() == input.dim_size(batch_dim),
"), ", "(", seq_lens.NumElements(), errors::InvalidArgument("Length of seq_lengths != input.dims(", batch_dim,
" vs. ", input.dim_size(batch_dim), ")")); "), ", "(", seq_lengths.NumElements(), " vs. ",
input.dim_size(batch_dim), ")"));
for (size_t d = 0; d < seq_lens_vec.size(); ++d) { for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
OP_REQUIRES(context, seq_lens_vec[d] >= 0, OP_REQUIRES(context, seq_lens_vec[d] >= 0,
@ -77,21 +78,22 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) { void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
const Tensor& input = context->input(0); const Tensor& input = context->input(0);
const Tensor& seq_lens = context->input(1); const Tensor& seq_lengths = context->input(1);
OP_REQUIRES(context, batch_dim != seq_dim, OP_REQUIRES(context, batch_dim != seq_dim,
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
OP_REQUIRES(context, seq_dim < input.dims(), OP_REQUIRES(context, seq_dim < input.dims(),
errors::InvalidArgument("seq_dim must be < input.dims()", "( ", errors::InvalidArgument("seq_dim must be < input rank", " ( ",
seq_dim, " vs. ", input.dims(), ")")); seq_dim, " vs. ", input.dims(), ")"));
OP_REQUIRES(context, batch_dim < input.dims(), OP_REQUIRES(context, batch_dim < input.dims(),
errors::InvalidArgument("batch_dim must be < input.dims()", "( ", errors::InvalidArgument("batch_dim must be < input rank", " ( ",
batch_dim, " vs. ", input.dims(), ")")); batch_dim, " vs. ", input.dims(), ")"));
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), OP_REQUIRES(
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, context, seq_lengths.NumElements() == input.dim_size(batch_dim),
"), ", "(", seq_lens.NumElements(), errors::InvalidArgument("Length of seq_lengths != input.dims(", batch_dim,
" vs. ", input.dim_size(batch_dim), ")")); "), ", "(", seq_lengths.NumElements(), " vs. ",
input.dim_size(batch_dim), ")"));
} }
template <> template <>
@ -117,14 +119,14 @@ class ReverseSequenceOp : public OpKernel {
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0); const Tensor& input = context->input(0);
const Tensor& seq_lens = context->input(1); const Tensor& seq_lengths = context->input(1);
// Preliminary validation of sizes. // Preliminary validation of sizes.
OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()), OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lengths.shape()),
errors::InvalidArgument("seq_lens input must be 1-dim, not ", errors::InvalidArgument("seq_lengths must be 1-dim, not ",
seq_lens.dims())); seq_lengths.dims()));
auto seq_lens_t = seq_lens.vec<Tlen>(); auto seq_lens_t = seq_lengths.vec<Tlen>();
CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_); CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
if (!context->status().ok()) return; if (!context->status().ok()) return;
@ -186,7 +188,7 @@ namespace functor {
void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute( \ void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute( \
const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \ const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
int32 batch_dim, int32 seq_dim, \ int32 batch_dim, int32 seq_dim, \
typename TTypes<Tlen>::ConstVec seq_lens, \ typename TTypes<Tlen>::ConstVec seq_lengths, \
typename TTypes<T, Dims>::Tensor output); \ typename TTypes<T, Dims>::Tensor output); \
extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>; extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>;

View File

@ -19,10 +19,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradient_checker
@ -135,56 +136,52 @@ class ReverseSequenceTest(test.TestCase):
print("ReverseSequence gradient error = %g" % err) print("ReverseSequence gradient error = %g" % err)
self.assertLess(err, 1e-8) self.assertLess(err, 1e-8)
@test_util.run_deprecated_v1
def testShapeFunctionEdgeCases(self): def testShapeFunctionEdgeCases(self):
t = array_ops.reverse_sequence( # Enter graph mode since we want to test partial shapes
array_ops.placeholder( with context.graph_mode():
dtypes.float32, shape=None), t = array_ops.reverse_sequence(
seq_lengths=array_ops.placeholder( array_ops.placeholder(dtypes.float32, shape=None),
dtypes.int64, shape=(32,)), seq_lengths=array_ops.placeholder(dtypes.int64, shape=(32,)),
batch_axis=0, batch_axis=0,
seq_axis=1) seq_axis=1)
self.assertIs(t.get_shape().ndims, None) self.assertIs(t.get_shape().ndims, None)
def testInvalidArguments(self):
# Batch size mismatched between input and seq_lengths. # Batch size mismatched between input and seq_lengths.
with self.assertRaises(ValueError): # seq_length too long
array_ops.reverse_sequence( with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
array_ops.placeholder( (r"Dimensions must be equal|"
dtypes.float32, shape=(32, 2, 3)), r"Length of seq_lengths != input.dims\(0\)")):
seq_lengths=array_ops.placeholder( array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2, 2], seq_axis=1)
dtypes.int64, shape=(33,)),
seq_axis=3) # seq_length too short
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
(r"Dimensions must be equal|"
r"Length of seq_lengths != input.dims\(0\)")):
array_ops.reverse_sequence([[1, 2], [3, 4]], [2], seq_axis=1)
# Invalid seq_length shape
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
("Shape must be rank 1 but is rank 2|"
"seq_lengths must be 1-dim")):
array_ops.reverse_sequence([[1, 2], [3, 4]], [[2, 2]], seq_axis=1)
# seq_axis out of bounds. # seq_axis out of bounds.
with self.assertRaisesRegexp(ValueError, "seq_dim must be < input rank"): with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
array_ops.reverse_sequence( "seq_dim must be < input rank"):
array_ops.placeholder( array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2], seq_axis=2)
dtypes.float32, shape=(32, 2, 3)),
seq_lengths=array_ops.placeholder(
dtypes.int64, shape=(32,)),
seq_axis=3)
# batch_axis out of bounds. # batch_axis out of bounds.
with self.assertRaisesRegexp(ValueError, "batch_dim must be < input rank"): with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
array_ops.reverse_sequence( "batch_dim must be < input rank"):
array_ops.placeholder( array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2],
dtypes.float32, shape=(32, 2, 3)), seq_axis=1,
seq_lengths=array_ops.placeholder( batch_axis=3)
dtypes.int64, shape=(32,)),
seq_axis=0,
batch_axis=3)
with self.cached_session(): with self.assertRaisesRegexp((errors.OpError, errors.InvalidArgumentError),
inputs = array_ops.placeholder(dtypes.float32, shape=(32, 2, 3)) "batch_dim == seq_dim == 0"):
seq_lengths = array_ops.placeholder(dtypes.int64, shape=(32,)) output = array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2], seq_axis=0)
output = array_ops.reverse_sequence( self.evaluate(output)
inputs, seq_lengths=seq_lengths,
seq_axis=0) # batch_axis default is 0
with self.assertRaisesOpError("batch_dim == seq_dim"):
output.eval(feed_dict={
inputs: np.random.rand(32, 2, 3),
seq_lengths: xrange(32)
})
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -4473,8 +4473,8 @@ def reverse_sequence(input,
dimension `seq_axis`. dimension `seq_axis`.
The elements of `seq_lengths` must obey `seq_lengths[i] <= The elements of `seq_lengths` must obey `seq_lengths[i] <=
input.dims[seq_dim]`, and `seq_lengths` must be a vector of length input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
`input.dims[batch_dim]`. `input.dims[batch_axis]`.
The output slice `i` along dimension `batch_axis` is then given by The output slice `i` along dimension `batch_axis` is then given by
input slice `i`, with the first `seq_lengths[i]` slices along input slice `i`, with the first `seq_lengths[i]` slices along
@ -4496,8 +4496,8 @@ def reverse_sequence(input,
Args: Args:
input: A `Tensor`. The input to reverse. input: A `Tensor`. The input to reverse.
seq_lengths: A `Tensor`. Must be one of the following types: `int32`, seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
`int64`. 1-D with length `input.dims(batch_dim)` and `max(seq_lengths) <= `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
input.dims(seq_dim)` input.dims(seq_axis)`
seq_axis: An `int`. The dimension which is partially reversed. seq_axis: An `int`. The dimension which is partially reversed.
batch_axis: An optional `int`. Defaults to `0`. The dimension along which batch_axis: An optional `int`. Defaults to `0`. The dimension along which
reversal is performed. reversal is performed.