Use tf.shape to get max_time inside _ApplyLengthsToBatch in case the tensor is dynamic shaped.

PiperOrigin-RevId: 209829459
This commit is contained in:
A. Unique TensorFlower 2018-08-22 14:23:35 -07:00 committed by TensorFlower Gardener
parent 5022fc95aa
commit 915fd68aa4

View File

@ -178,7 +178,8 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
# TODO(drpng): just use Update so that we don't carry over the gradients? # TODO(drpng): just use Update so that we don't carry over the gradients?
"""Sets the output to be zero at the end of the sequence.""" """Sets the output to be zero at the end of the sequence."""
# output is batch major. # output is batch major.
batch_size, max_time, vector_size = tf_output.shape shape = array_ops.shape(tf_output)
batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size]) output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
output_time = array_ops.reshape(output_time, [batch_size, max_time]) output_time = array_ops.reshape(output_time, [batch_size, max_time])
lengths = array_ops.tile( lengths = array_ops.tile(