From 91969ae68e30eef368698094a7f574a3bcfdf315 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Fri, 8 Apr 2016 06:04:37 -0800 Subject: [PATCH] Fixes `tf.reverse_sequence()` to work when input shape is unknown. Fixes #1816. Change: 119369338 --- .../kernel_tests/reverse_sequence_op_test.py | 6 ++++++ tensorflow/python/ops/array_ops.py | 15 +++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py index d5b087298e1..ccb19009fbd 100644 --- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py +++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py @@ -114,6 +114,12 @@ class ReverseSequenceTest(tf.test.TestCase): self.assertLess(err, 1e-8) def testShapeFunctionEdgeCases(self): + t = tf.reverse_sequence( + tf.placeholder(tf.float32, shape=None), + seq_lengths=tf.placeholder(tf.int64, shape=(32,)), + batch_dim=0, seq_dim=1) + self.assertIs(t.get_shape().ndims, None) + # Batch size mismatched between input and seq_lengths. with self.assertRaises(ValueError): tf.reverse_sequence( diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 06f7fd444ed..0cb751553c4 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1254,14 +1254,17 @@ def _ReverseSequenceShape(op): """ input_shape = op.inputs[0].get_shape() seq_lens_shape = op.inputs[1].get_shape().with_rank(1) + if input_shape.ndims is None: + return [None] seq_dim = op.get_attr("seq_dim") batch_dim = op.get_attr("batch_dim") - if batch_dim >= input_shape.ndims: - raise ValueError("batch_dim must be < input.dims() (%d vs %d)" % - (batch_dim, input_shape.ndims)) - if seq_dim >= input_shape.ndims: - raise ValueError("seq_dim must be < input.dims() (%d vs %d)" % - (seq_dim, input_shape.ndims)) + if input_shape.ndims is not None: + if batch_dim >= input_shape.ndims: + raise ValueError("batch_dim must be < input.dims() (%d vs %d)" % + (batch_dim, input_shape.ndims)) + if seq_dim >= input_shape.ndims: + raise ValueError("seq_dim must be < input.dims() (%d vs %d)" % + (seq_dim, input_shape.ndims)) batch_size = input_shape[batch_dim].merge_with(seq_lens_shape[0]) input_shape = tensor_shape.TensorShape([ value if ix != batch_dim else batch_size