Fixes tf.reverse_sequence() to work when input shape is unknown.

Fixes #1816.
Change: 119369338
This commit is contained in:
Derek Murray 2016-04-08 06:04:37 -08:00 committed by TensorFlower Gardener
parent 6cdcc59451
commit 91969ae68e
2 changed files with 15 additions and 6 deletions

View File

@ -114,6 +114,12 @@ class ReverseSequenceTest(tf.test.TestCase):
self.assertLess(err, 1e-8)
def testShapeFunctionEdgeCases(self):
t = tf.reverse_sequence(
tf.placeholder(tf.float32, shape=None),
seq_lengths=tf.placeholder(tf.int64, shape=(32,)),
batch_dim=0, seq_dim=1)
self.assertIs(t.get_shape().ndims, None)
# Batch size mismatched between input and seq_lengths.
with self.assertRaises(ValueError):
tf.reverse_sequence(

View File

@ -1254,14 +1254,17 @@ def _ReverseSequenceShape(op):
"""
input_shape = op.inputs[0].get_shape()
seq_lens_shape = op.inputs[1].get_shape().with_rank(1)
if input_shape.ndims is None:
return [None]
seq_dim = op.get_attr("seq_dim")
batch_dim = op.get_attr("batch_dim")
if batch_dim >= input_shape.ndims:
raise ValueError("batch_dim must be < input.dims() (%d vs %d)" %
(batch_dim, input_shape.ndims))
if seq_dim >= input_shape.ndims:
raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
(seq_dim, input_shape.ndims))
if input_shape.ndims is not None:
if batch_dim >= input_shape.ndims:
raise ValueError("batch_dim must be < input.dims() (%d vs %d)" %
(batch_dim, input_shape.ndims))
if seq_dim >= input_shape.ndims:
raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
(seq_dim, input_shape.ndims))
batch_size = input_shape[batch_dim].merge_with(seq_lens_shape[0])
input_shape = tensor_shape.TensorShape([
value if ix != batch_dim else batch_size