Use tf.shape to get max_time inside _ApplyLengthsToBatch in case the tensor is dynamically shaped.

PiperOrigin-RevId: 209829459
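
For context (not part of the commit): a minimal sketch of why the change is needed, assuming a TF 1.x-style graph with unknown batch and time dimensions (the placeholder and the vector size 8 are illustrative). The static tf_output.shape reports None for dynamic dimensions, while tf.shape (array_ops.shape) returns a runtime int32 tensor whose slices can be fed to tf.range and tf.tile.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Batch-major output whose batch and time dimensions are only known at run time.
tf_output = tf.placeholder(tf.float32, shape=[None, None, 8])

# Static shape: the dynamic dimensions come back as None, which cannot be
# passed to ops such as tf.range or tf.tile.
print(tf_output.shape)  # unknown batch/time dims show as None (or ? in TF1)

# Dynamic shape: tf.shape returns an int32 tensor evaluated at run time, so
# its slices are valid arguments to tf.range / tf.tile / tf.reshape.
shape = tf.shape(tf_output)
batch_size, max_time = shape[0], shape[1]
output_time = tf.reshape(
    tf.tile(tf.range(0, max_time), [batch_size]), [batch_size, max_time])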
A. Unique TensorFlower 2018-08-22 14:23:35 -07:00 committed by TensorFlower Gardener
parent 5022fc95aa
commit 915fd68aa4


@@ -178,7 +178,8 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
   # TODO(drpng): just use Update so that we don't carry over the gradients?
   """Sets the output to be zero at the end of the sequence."""
   # output is batch major.
-  batch_size, max_time, vector_size = tf_output.shape
+  shape = array_ops.shape(tf_output)
+  batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
   output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
   output_time = array_ops.reshape(output_time, [batch_size, max_time])
   lengths = array_ops.tile(