Fixes tf.reverse_sequence()
to work when input shape is unknown.
Fixes #1816. Change: 119369338
This commit is contained in:
parent
6cdcc59451
commit
91969ae68e
@ -114,6 +114,12 @@ class ReverseSequenceTest(tf.test.TestCase):
|
|||||||
self.assertLess(err, 1e-8)
|
self.assertLess(err, 1e-8)
|
||||||
|
|
||||||
def testShapeFunctionEdgeCases(self):
|
def testShapeFunctionEdgeCases(self):
|
||||||
|
t = tf.reverse_sequence(
|
||||||
|
tf.placeholder(tf.float32, shape=None),
|
||||||
|
seq_lengths=tf.placeholder(tf.int64, shape=(32,)),
|
||||||
|
batch_dim=0, seq_dim=1)
|
||||||
|
self.assertIs(t.get_shape().ndims, None)
|
||||||
|
|
||||||
# Batch size mismatched between input and seq_lengths.
|
# Batch size mismatched between input and seq_lengths.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
tf.reverse_sequence(
|
tf.reverse_sequence(
|
||||||
|
@ -1254,14 +1254,17 @@ def _ReverseSequenceShape(op):
|
|||||||
"""
|
"""
|
||||||
input_shape = op.inputs[0].get_shape()
|
input_shape = op.inputs[0].get_shape()
|
||||||
seq_lens_shape = op.inputs[1].get_shape().with_rank(1)
|
seq_lens_shape = op.inputs[1].get_shape().with_rank(1)
|
||||||
|
if input_shape.ndims is None:
|
||||||
|
return [None]
|
||||||
seq_dim = op.get_attr("seq_dim")
|
seq_dim = op.get_attr("seq_dim")
|
||||||
batch_dim = op.get_attr("batch_dim")
|
batch_dim = op.get_attr("batch_dim")
|
||||||
if batch_dim >= input_shape.ndims:
|
if input_shape.ndims is not None:
|
||||||
raise ValueError("batch_dim must be < input.dims() (%d vs %d)" %
|
if batch_dim >= input_shape.ndims:
|
||||||
(batch_dim, input_shape.ndims))
|
raise ValueError("batch_dim must be < input.dims() (%d vs %d)" %
|
||||||
if seq_dim >= input_shape.ndims:
|
(batch_dim, input_shape.ndims))
|
||||||
raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
|
if seq_dim >= input_shape.ndims:
|
||||||
(seq_dim, input_shape.ndims))
|
raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
|
||||||
|
(seq_dim, input_shape.ndims))
|
||||||
batch_size = input_shape[batch_dim].merge_with(seq_lens_shape[0])
|
batch_size = input_shape[batch_dim].merge_with(seq_lens_shape[0])
|
||||||
input_shape = tensor_shape.TensorShape([
|
input_shape = tensor_shape.TensorShape([
|
||||||
value if ix != batch_dim else batch_size
|
value if ix != batch_dim else batch_size
|
||||||
|
Loading…
Reference in New Issue
Block a user