Fixes tf.reverse_sequence()
to work when input shape is unknown.
Fixes #1816. Change: 119369338
This commit is contained in:
parent
6cdcc59451
commit
91969ae68e
@ -114,6 +114,12 @@ class ReverseSequenceTest(tf.test.TestCase):
|
||||
self.assertLess(err, 1e-8)
|
||||
|
||||
def testShapeFunctionEdgeCases(self):
|
||||
t = tf.reverse_sequence(
|
||||
tf.placeholder(tf.float32, shape=None),
|
||||
seq_lengths=tf.placeholder(tf.int64, shape=(32,)),
|
||||
batch_dim=0, seq_dim=1)
|
||||
self.assertIs(t.get_shape().ndims, None)
|
||||
|
||||
# Batch size mismatched between input and seq_lengths.
|
||||
with self.assertRaises(ValueError):
|
||||
tf.reverse_sequence(
|
||||
|
@ -1254,14 +1254,17 @@ def _ReverseSequenceShape(op):
|
||||
"""
|
||||
input_shape = op.inputs[0].get_shape()
|
||||
seq_lens_shape = op.inputs[1].get_shape().with_rank(1)
|
||||
if input_shape.ndims is None:
|
||||
return [None]
|
||||
seq_dim = op.get_attr("seq_dim")
|
||||
batch_dim = op.get_attr("batch_dim")
|
||||
if batch_dim >= input_shape.ndims:
|
||||
raise ValueError("batch_dim must be < input.dims() (%d vs %d)" %
|
||||
(batch_dim, input_shape.ndims))
|
||||
if seq_dim >= input_shape.ndims:
|
||||
raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
|
||||
(seq_dim, input_shape.ndims))
|
||||
if input_shape.ndims is not None:
|
||||
if batch_dim >= input_shape.ndims:
|
||||
raise ValueError("batch_dim must be < input.dims() (%d vs %d)" %
|
||||
(batch_dim, input_shape.ndims))
|
||||
if seq_dim >= input_shape.ndims:
|
||||
raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
|
||||
(seq_dim, input_shape.ndims))
|
||||
batch_size = input_shape[batch_dim].merge_with(seq_lens_shape[0])
|
||||
input_shape = tensor_shape.TensorShape([
|
||||
value if ix != batch_dim else batch_size
|
||||
|
Loading…
Reference in New Issue
Block a user