Use tf.shape to get max_time
inside _ApplyLengthsToBatch in case the tensor is dynamic shaped.
PiperOrigin-RevId: 209829459
This commit is contained in:
parent
5022fc95aa
commit
915fd68aa4
@ -178,7 +178,8 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
|
|||||||
# TODO(drpng): just use Update so that we don't carry over the gradients?
|
# TODO(drpng): just use Update so that we don't carry over the gradients?
|
||||||
"""Sets the output to be zero at the end of the sequence."""
|
"""Sets the output to be zero at the end of the sequence."""
|
||||||
# output is batch major.
|
# output is batch major.
|
||||||
batch_size, max_time, vector_size = tf_output.shape
|
shape = array_ops.shape(tf_output)
|
||||||
|
batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
|
||||||
output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
|
output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
|
||||||
output_time = array_ops.reshape(output_time, [batch_size, max_time])
|
output_time = array_ops.reshape(output_time, [batch_size, max_time])
|
||||||
lengths = array_ops.tile(
|
lengths = array_ops.tile(
|
||||||
|
Loading…
Reference in New Issue
Block a user