Fix argument check tests to work in eager mode
PiperOrigin-RevId: 312170271 Change-Id: Ie7ffb52cf63559255b5463d651eb72b924a3c3bf
This commit is contained in:
parent
94108993a3
commit
1acf6989bf
@ -43,9 +43,9 @@ typedef Eigen::GpuDevice GPUDevice;
|
|||||||
template <typename Device, typename Tlen>
|
template <typename Device, typename Tlen>
|
||||||
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
|
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
|
||||||
const Tensor& input = context->input(0);
|
const Tensor& input = context->input(0);
|
||||||
const Tensor& seq_lens = context->input(1);
|
const Tensor& seq_lengths = context->input(1);
|
||||||
|
|
||||||
auto seq_lens_t = seq_lens.vec<Tlen>();
|
auto seq_lens_t = seq_lengths.vec<Tlen>();
|
||||||
|
|
||||||
std::vector<Tlen> seq_lens_vec(seq_lens_t.size());
|
std::vector<Tlen> seq_lens_vec(seq_lens_t.size());
|
||||||
|
|
||||||
@ -56,15 +56,16 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
|
|||||||
OP_REQUIRES(context, batch_dim != seq_dim,
|
OP_REQUIRES(context, batch_dim != seq_dim,
|
||||||
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
|
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
|
||||||
OP_REQUIRES(context, seq_dim < input.dims(),
|
OP_REQUIRES(context, seq_dim < input.dims(),
|
||||||
errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
|
errors::InvalidArgument("seq_dim must be < input rank", " ( ",
|
||||||
seq_dim, " vs. ", input.dims(), ")"));
|
seq_dim, " vs. ", input.dims(), ")"));
|
||||||
OP_REQUIRES(context, batch_dim < input.dims(),
|
OP_REQUIRES(context, batch_dim < input.dims(),
|
||||||
errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
|
errors::InvalidArgument("batch_dim must be < input rank", " ( ",
|
||||||
batch_dim, " vs. ", input.dims(), ")"));
|
batch_dim, " vs. ", input.dims(), ")"));
|
||||||
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
|
OP_REQUIRES(
|
||||||
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
|
context, seq_lengths.NumElements() == input.dim_size(batch_dim),
|
||||||
"), ", "(", seq_lens.NumElements(),
|
errors::InvalidArgument("Length of seq_lengths != input.dims(", batch_dim,
|
||||||
" vs. ", input.dim_size(batch_dim), ")"));
|
"), ", "(", seq_lengths.NumElements(), " vs. ",
|
||||||
|
input.dim_size(batch_dim), ")"));
|
||||||
|
|
||||||
for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
|
for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
|
||||||
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
|
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
|
||||||
@ -77,21 +78,22 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
|
|||||||
|
|
||||||
void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
|
void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
|
||||||
const Tensor& input = context->input(0);
|
const Tensor& input = context->input(0);
|
||||||
const Tensor& seq_lens = context->input(1);
|
const Tensor& seq_lengths = context->input(1);
|
||||||
|
|
||||||
OP_REQUIRES(context, batch_dim != seq_dim,
|
OP_REQUIRES(context, batch_dim != seq_dim,
|
||||||
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
|
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
|
||||||
OP_REQUIRES(context, seq_dim < input.dims(),
|
OP_REQUIRES(context, seq_dim < input.dims(),
|
||||||
errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
|
errors::InvalidArgument("seq_dim must be < input rank", " ( ",
|
||||||
seq_dim, " vs. ", input.dims(), ")"));
|
seq_dim, " vs. ", input.dims(), ")"));
|
||||||
OP_REQUIRES(context, batch_dim < input.dims(),
|
OP_REQUIRES(context, batch_dim < input.dims(),
|
||||||
errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
|
errors::InvalidArgument("batch_dim must be < input rank", " ( ",
|
||||||
batch_dim, " vs. ", input.dims(), ")"));
|
batch_dim, " vs. ", input.dims(), ")"));
|
||||||
|
|
||||||
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
|
OP_REQUIRES(
|
||||||
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
|
context, seq_lengths.NumElements() == input.dim_size(batch_dim),
|
||||||
"), ", "(", seq_lens.NumElements(),
|
errors::InvalidArgument("Length of seq_lengths != input.dims(", batch_dim,
|
||||||
" vs. ", input.dim_size(batch_dim), ")"));
|
"), ", "(", seq_lengths.NumElements(), " vs. ",
|
||||||
|
input.dim_size(batch_dim), ")"));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
@ -117,14 +119,14 @@ class ReverseSequenceOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
const Tensor& input = context->input(0);
|
const Tensor& input = context->input(0);
|
||||||
const Tensor& seq_lens = context->input(1);
|
const Tensor& seq_lengths = context->input(1);
|
||||||
|
|
||||||
// Preliminary validation of sizes.
|
// Preliminary validation of sizes.
|
||||||
OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()),
|
OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lengths.shape()),
|
||||||
errors::InvalidArgument("seq_lens input must be 1-dim, not ",
|
errors::InvalidArgument("seq_lengths must be 1-dim, not ",
|
||||||
seq_lens.dims()));
|
seq_lengths.dims()));
|
||||||
|
|
||||||
auto seq_lens_t = seq_lens.vec<Tlen>();
|
auto seq_lens_t = seq_lengths.vec<Tlen>();
|
||||||
|
|
||||||
CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
|
CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
|
||||||
if (!context->status().ok()) return;
|
if (!context->status().ok()) return;
|
||||||
@ -186,7 +188,7 @@ namespace functor {
|
|||||||
void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute( \
|
void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute( \
|
||||||
const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
|
const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
|
||||||
int32 batch_dim, int32 seq_dim, \
|
int32 batch_dim, int32 seq_dim, \
|
||||||
typename TTypes<Tlen>::ConstVec seq_lens, \
|
typename TTypes<Tlen>::ConstVec seq_lengths, \
|
||||||
typename TTypes<T, Dims>::Tensor output); \
|
typename TTypes<T, Dims>::Tensor output); \
|
||||||
extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>;
|
extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>;
|
||||||
|
|
||||||
|
@ -19,10 +19,11 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gradient_checker
|
from tensorflow.python.ops import gradient_checker
|
||||||
@ -135,56 +136,52 @@ class ReverseSequenceTest(test.TestCase):
|
|||||||
print("ReverseSequence gradient error = %g" % err)
|
print("ReverseSequence gradient error = %g" % err)
|
||||||
self.assertLess(err, 1e-8)
|
self.assertLess(err, 1e-8)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testShapeFunctionEdgeCases(self):
|
def testShapeFunctionEdgeCases(self):
|
||||||
t = array_ops.reverse_sequence(
|
# Enter graph mode since we want to test partial shapes
|
||||||
array_ops.placeholder(
|
with context.graph_mode():
|
||||||
dtypes.float32, shape=None),
|
t = array_ops.reverse_sequence(
|
||||||
seq_lengths=array_ops.placeholder(
|
array_ops.placeholder(dtypes.float32, shape=None),
|
||||||
dtypes.int64, shape=(32,)),
|
seq_lengths=array_ops.placeholder(dtypes.int64, shape=(32,)),
|
||||||
batch_axis=0,
|
batch_axis=0,
|
||||||
seq_axis=1)
|
seq_axis=1)
|
||||||
self.assertIs(t.get_shape().ndims, None)
|
self.assertIs(t.get_shape().ndims, None)
|
||||||
|
|
||||||
|
def testInvalidArguments(self):
|
||||||
# Batch size mismatched between input and seq_lengths.
|
# Batch size mismatched between input and seq_lengths.
|
||||||
with self.assertRaises(ValueError):
|
# seq_length too long
|
||||||
array_ops.reverse_sequence(
|
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||||
array_ops.placeholder(
|
(r"Dimensions must be equal|"
|
||||||
dtypes.float32, shape=(32, 2, 3)),
|
r"Length of seq_lengths != input.dims\(0\)")):
|
||||||
seq_lengths=array_ops.placeholder(
|
array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2, 2], seq_axis=1)
|
||||||
dtypes.int64, shape=(33,)),
|
|
||||||
seq_axis=3)
|
# seq_length too short
|
||||||
|
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||||
|
(r"Dimensions must be equal|"
|
||||||
|
r"Length of seq_lengths != input.dims\(0\)")):
|
||||||
|
array_ops.reverse_sequence([[1, 2], [3, 4]], [2], seq_axis=1)
|
||||||
|
|
||||||
|
# Invalid seq_length shape
|
||||||
|
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||||
|
("Shape must be rank 1 but is rank 2|"
|
||||||
|
"seq_lengths must be 1-dim")):
|
||||||
|
array_ops.reverse_sequence([[1, 2], [3, 4]], [[2, 2]], seq_axis=1)
|
||||||
|
|
||||||
# seq_axis out of bounds.
|
# seq_axis out of bounds.
|
||||||
with self.assertRaisesRegexp(ValueError, "seq_dim must be < input rank"):
|
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||||
array_ops.reverse_sequence(
|
"seq_dim must be < input rank"):
|
||||||
array_ops.placeholder(
|
array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2], seq_axis=2)
|
||||||
dtypes.float32, shape=(32, 2, 3)),
|
|
||||||
seq_lengths=array_ops.placeholder(
|
|
||||||
dtypes.int64, shape=(32,)),
|
|
||||||
seq_axis=3)
|
|
||||||
|
|
||||||
# batch_axis out of bounds.
|
# batch_axis out of bounds.
|
||||||
with self.assertRaisesRegexp(ValueError, "batch_dim must be < input rank"):
|
with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError),
|
||||||
array_ops.reverse_sequence(
|
"batch_dim must be < input rank"):
|
||||||
array_ops.placeholder(
|
array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2],
|
||||||
dtypes.float32, shape=(32, 2, 3)),
|
seq_axis=1,
|
||||||
seq_lengths=array_ops.placeholder(
|
batch_axis=3)
|
||||||
dtypes.int64, shape=(32,)),
|
|
||||||
seq_axis=0,
|
|
||||||
batch_axis=3)
|
|
||||||
|
|
||||||
with self.cached_session():
|
with self.assertRaisesRegexp((errors.OpError, errors.InvalidArgumentError),
|
||||||
inputs = array_ops.placeholder(dtypes.float32, shape=(32, 2, 3))
|
"batch_dim == seq_dim == 0"):
|
||||||
seq_lengths = array_ops.placeholder(dtypes.int64, shape=(32,))
|
output = array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2], seq_axis=0)
|
||||||
output = array_ops.reverse_sequence(
|
self.evaluate(output)
|
||||||
inputs, seq_lengths=seq_lengths,
|
|
||||||
seq_axis=0) # batch_axis default is 0
|
|
||||||
with self.assertRaisesOpError("batch_dim == seq_dim"):
|
|
||||||
output.eval(feed_dict={
|
|
||||||
inputs: np.random.rand(32, 2, 3),
|
|
||||||
seq_lengths: xrange(32)
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -4473,8 +4473,8 @@ def reverse_sequence(input,
|
|||||||
dimension `seq_axis`.
|
dimension `seq_axis`.
|
||||||
|
|
||||||
The elements of `seq_lengths` must obey `seq_lengths[i] <=
|
The elements of `seq_lengths` must obey `seq_lengths[i] <=
|
||||||
input.dims[seq_dim]`, and `seq_lengths` must be a vector of length
|
input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
|
||||||
`input.dims[batch_dim]`.
|
`input.dims[batch_axis]`.
|
||||||
|
|
||||||
The output slice `i` along dimension `batch_axis` is then given by
|
The output slice `i` along dimension `batch_axis` is then given by
|
||||||
input slice `i`, with the first `seq_lengths[i]` slices along
|
input slice `i`, with the first `seq_lengths[i]` slices along
|
||||||
@ -4496,8 +4496,8 @@ def reverse_sequence(input,
|
|||||||
Args:
|
Args:
|
||||||
input: A `Tensor`. The input to reverse.
|
input: A `Tensor`. The input to reverse.
|
||||||
seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
|
seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
|
||||||
`int64`. 1-D with length `input.dims(batch_dim)` and `max(seq_lengths) <=
|
`int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
|
||||||
input.dims(seq_dim)`
|
input.dims(seq_axis)`
|
||||||
seq_axis: An `int`. The dimension which is partially reversed.
|
seq_axis: An `int`. The dimension which is partially reversed.
|
||||||
batch_axis: An optional `int`. Defaults to `0`. The dimension along which
|
batch_axis: An optional `int`. Defaults to `0`. The dimension along which
|
||||||
reversal is performed.
|
reversal is performed.
|
||||||
|
Loading…
Reference in New Issue
Block a user