diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 0af72231b00..3c225abbbfc 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -294,10 +294,6 @@ def _reverse_seq(input_seq, lengths): # Join into (time, batch_size, depth) s_joined = array_ops.stack(sequence) - # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32 - if lengths is not None: - lengths = math_ops.to_int64(lengths) - # Reverse along dimension 0 s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) # Split again into list