Use tf.shape to get max_time
inside _ApplyLengthsToBatch in case the tensor is dynamic shaped.
PiperOrigin-RevId: 209829459
This commit is contained in:
parent
5022fc95aa
commit
915fd68aa4
@ -178,7 +178,8 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
|
||||
# TODO(drpng): just use Update so that we don't carry over the gradients?
|
||||
"""Sets the output to be zero at the end of the sequence."""
|
||||
# output is batch major.
|
||||
batch_size, max_time, vector_size = tf_output.shape
|
||||
shape = array_ops.shape(tf_output)
|
||||
batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
|
||||
output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
|
||||
output_time = array_ops.reshape(output_time, [batch_size, max_time])
|
||||
lengths = array_ops.tile(
|
||||
|
Loading…
Reference in New Issue
Block a user