Fix argument check tests to work in eager mode
PiperOrigin-RevId: 312170271 Change-Id: Ie7ffb52cf63559255b5463d651eb72b924a3c3bf
This commit is contained in:
parent
94108993a3
commit
1acf6989bf
@ -43,9 +43,9 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
template <typename Device, typename Tlen>
|
||||
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
|
||||
const Tensor& input = context->input(0);
|
||||
const Tensor& seq_lens = context->input(1);
|
||||
const Tensor& seq_lengths = context->input(1);
|
||||
|
||||
auto seq_lens_t = seq_lens.vec<Tlen>();
|
||||
auto seq_lens_t = seq_lengths.vec<Tlen>();
|
||||
|
||||
std::vector<Tlen> seq_lens_vec(seq_lens_t.size());
|
||||
|
||||
@ -56,15 +56,16 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
|
||||
OP_REQUIRES(context, batch_dim != seq_dim,
|
||||
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
|
||||
OP_REQUIRES(context, seq_dim < input.dims(),
|
||||
errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
|
||||
errors::InvalidArgument("seq_dim must be < input rank", " ( ",
|
||||
seq_dim, " vs. ", input.dims(), ")"));
|
||||
OP_REQUIRES(context, batch_dim < input.dims(),
|
||||
errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
|
||||
errors::InvalidArgument("batch_dim must be < input rank", " ( ",
|
||||
batch_dim, " vs. ", input.dims(), ")"));
|
||||
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
|
||||
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
|
||||
"), ", "(", seq_lens.NumElements(),
|
||||
" vs. ", input.dim_size(batch_dim), ")"));
|
||||
OP_REQUIRES(
|
||||
context, seq_lengths.NumElements() == input.dim_size(batch_dim),
|
||||
errors::InvalidArgument("Length of seq_lengths != input.dims(", batch_dim,
|
||||
"), ", "(", seq_lengths.NumElements(), " vs. ",
|
||||
input.dim_size(batch_dim), ")"));
|
||||
|
||||
for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
|
||||
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
|
||||
@ -77,21 +78,22 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
|
||||
|
||||
void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
|
||||
const Tensor& input = context->input(0);
|
||||
const Tensor& seq_lens = context->input(1);
|
||||
const Tensor& seq_lengths = context->input(1);
|
||||
|
||||
OP_REQUIRES(context, batch_dim != seq_dim,
|
||||
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
|
||||
OP_REQUIRES(context, seq_dim < input.dims(),
|
||||
errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
|
||||
errors::InvalidArgument("seq_dim must be < input rank", " ( ",
|
||||
seq_dim, " vs. ", input.dims(), ")"));
|
||||
OP_REQUIRES(context, batch_dim < input.dims(),
|
||||
errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
|
||||
errors::InvalidArgument("batch_dim must be < input rank", " ( ",
|
||||
batch_dim, " vs. ", input.dims(), ")"));
|
||||
|
||||
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
|
||||
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
|
||||
"), ", "(", seq_lens.NumElements(),
|
||||
" vs. ", input.dim_size(batch_dim), ")"));
|
||||
OP_REQUIRES(
|
||||
context, seq_lengths.NumElements() == input.dim_size(batch_dim),
|
||||
errors::InvalidArgument("Length of seq_lengths != input.dims(", batch_dim,
|
||||
"), ", "(", seq_lengths.NumElements(), " vs. ",
|
||||
input.dim_size(batch_dim), ")"));
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -117,14 +119,14 @@ class ReverseSequenceOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input = context->input(0);
|
||||
const Tensor& seq_lens = context->input(1);
|
||||
const Tensor& seq_lengths = context->input(1);
|
||||
|
||||
// Preliminary validation of sizes.
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()),
|
||||
errors::InvalidArgument("seq_lens input must be 1-dim, not ",
|
||||
seq_lens.dims()));
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lengths.shape()),
|
||||
errors::InvalidArgument("seq_lengths must be 1-dim, not ",
|
||||
seq_lengths.dims()));
|
||||
|
||||
auto seq_lens_t = seq_lens.vec<Tlen>();
|
||||
auto seq_lens_t = seq_lengths.vec<Tlen>();
|
||||
|
||||
CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
|
||||
if (!context->status().ok()) return;
|
||||
@ -186,7 +188,7 @@ namespace functor {
|
||||
void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute( \
|
||||
const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
|
||||
int32 batch_dim, int32 seq_dim, \
|
||||
typename TTypes<Tlen>::ConstVec seq_lens, \
|
||||
typename TTypes<Tlen>::ConstVec seq_lengths, \
|
||||
typename TTypes<T, Dims>::Tensor output); \
|
||||
extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>;
|
||||
|
||||
|
@ -19,10 +19,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
@ -135,56 +136,52 @@ class ReverseSequenceTest(test.TestCase):
|
||||
print("ReverseSequence gradient error = %g" % err)
|
||||
self.assertLess(err, 1e-8)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShapeFunctionEdgeCases(self):
|
||||
t = array_ops.reverse_sequence(
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=None),
|
||||
seq_lengths=array_ops.placeholder(
|
||||
dtypes.int64, shape=(32,)),
|
||||
batch_axis=0,
|
||||
seq_axis=1)
|
||||
self.assertIs(t.get_shape().ndims, None)
|
||||
# Enter graph mode since we want to test partial shapes
|
||||
with context.graph_mode():
|
||||
t = array_ops.reverse_sequence(
|
||||
array_ops.placeholder(dtypes.float32, shape=None),
|
||||
seq_lengths=array_ops.placeholder(dtypes.int64, shape=(32,)),
|
||||
batch_axis=0,
|
||||
seq_axis=1)
|
||||
self.assertIs(t.get_shape().ndims, None)
|
||||
|
||||
def testInvalidArguments(self):
|
||||
# Batch size mismatched between input and seq_lengths.
|
||||
with self.assertRaises(ValueError):
|
||||
array_ops.reverse_sequence(
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(32, 2, 3)),
|
||||
seq_lengths=array_ops.placeholder(
|
||||
dtypes.int64, shape=(33,)),
|
||||
seq_axis=3)
|
||||
# seq_length too long
|
||||
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||
(r"Dimensions must be equal|"
|
||||
r"Length of seq_lengths != input.dims\(0\)")):
|
||||
array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2, 2], seq_axis=1)
|
||||
|
||||
# seq_length too short
|
||||
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||
(r"Dimensions must be equal|"
|
||||
r"Length of seq_lengths != input.dims\(0\)")):
|
||||
array_ops.reverse_sequence([[1, 2], [3, 4]], [2], seq_axis=1)
|
||||
|
||||
# Invalid seq_length shape
|
||||
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||
("Shape must be rank 1 but is rank 2|"
|
||||
"seq_lengths must be 1-dim")):
|
||||
array_ops.reverse_sequence([[1, 2], [3, 4]], [[2, 2]], seq_axis=1)
|
||||
|
||||
# seq_axis out of bounds.
|
||||
with self.assertRaisesRegexp(ValueError, "seq_dim must be < input rank"):
|
||||
array_ops.reverse_sequence(
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(32, 2, 3)),
|
||||
seq_lengths=array_ops.placeholder(
|
||||
dtypes.int64, shape=(32,)),
|
||||
seq_axis=3)
|
||||
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||
"seq_dim must be < input rank"):
|
||||
array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2], seq_axis=2)
|
||||
|
||||
# batch_axis out of bounds.
|
||||
with self.assertRaisesRegexp(ValueError, "batch_dim must be < input rank"):
|
||||
array_ops.reverse_sequence(
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(32, 2, 3)),
|
||||
seq_lengths=array_ops.placeholder(
|
||||
dtypes.int64, shape=(32,)),
|
||||
seq_axis=0,
|
||||
batch_axis=3)
|
||||
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||
"batch_dim must be < input rank"):
|
||||
array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2],
|
||||
seq_axis=1,
|
||||
batch_axis=3)
|
||||
|
||||
with self.cached_session():
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=(32, 2, 3))
|
||||
seq_lengths = array_ops.placeholder(dtypes.int64, shape=(32,))
|
||||
output = array_ops.reverse_sequence(
|
||||
inputs, seq_lengths=seq_lengths,
|
||||
seq_axis=0) # batch_axis default is 0
|
||||
with self.assertRaisesOpError("batch_dim == seq_dim"):
|
||||
output.eval(feed_dict={
|
||||
inputs: np.random.rand(32, 2, 3),
|
||||
seq_lengths: xrange(32)
|
||||
})
|
||||
with self.assertRaisesRegexp((errors.OpError, errors.InvalidArgumentError),
|
||||
"batch_dim == seq_dim == 0"):
|
||||
output = array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2], seq_axis=0)
|
||||
self.evaluate(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -4473,8 +4473,8 @@ def reverse_sequence(input,
|
||||
dimension `seq_axis`.
|
||||
|
||||
The elements of `seq_lengths` must obey `seq_lengths[i] <=
|
||||
input.dims[seq_dim]`, and `seq_lengths` must be a vector of length
|
||||
`input.dims[batch_dim]`.
|
||||
input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
|
||||
`input.dims[batch_axis]`.
|
||||
|
||||
The output slice `i` along dimension `batch_axis` is then given by
|
||||
input slice `i`, with the first `seq_lengths[i]` slices along
|
||||
@ -4496,8 +4496,8 @@ def reverse_sequence(input,
|
||||
Args:
|
||||
input: A `Tensor`. The input to reverse.
|
||||
seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
|
||||
`int64`. 1-D with length `input.dims(batch_dim)` and `max(seq_lengths) <=
|
||||
input.dims(seq_dim)`
|
||||
`int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
|
||||
input.dims(seq_axis)`
|
||||
seq_axis: An `int`. The dimension which is partially reversed.
|
||||
batch_axis: An optional `int`. Defaults to `0`. The dimension along which
|
||||
reversal is performed.
|
||||
|
Loading…
Reference in New Issue
Block a user