diff --git a/tensorflow/models/rnn/translate/translate.py b/tensorflow/models/rnn/translate/translate.py
index 083417ba5ae..cf870c71f6e 100644
--- a/tensorflow/models/rnn/translate/translate.py
+++ b/tensorflow/models/rnn/translate/translate.py
@@ -36,6 +36,7 @@ import os
 import random
 import sys
 import time
+import logging
 
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -238,8 +239,14 @@ def decode():
       # Get token-ids for the input sentence.
       token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)
       # Which bucket does it belong to?
-      bucket_id = min([b for b in xrange(len(_buckets))
-                       if _buckets[b][0] > len(token_ids)])
+      bucket_id = len(_buckets) - 1
+      for i, bucket in enumerate(_buckets):
+        if bucket[0] >= len(token_ids):
+          bucket_id = i
+          break
+      else:
+        logging.warning("Sentence truncated: %s", sentence)
+
       # Get a 1-element batch to feed the sentence to the model.
       encoder_inputs, decoder_inputs, target_weights = model.get_batch(
           {bucket_id: [(token_ids, [])]}, bucket_id)
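
For reference, the new selection logic can be exercised on its own. Below is a minimal sketch, not part of the patch, using illustrative bucket sizes and a hypothetical pick_bucket helper rather than the tutorial's actual _buckets. It shows the for/else choosing the smallest bucket that fits and falling back to the largest bucket with a truncation warning when no bucket is wide enough (the case where the old min(...) over an empty list raised a ValueError).

import logging

# Illustrative bucket sizes only; translate.py defines its own _buckets.
_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]


def pick_bucket(token_ids, sentence):
  """Hypothetical helper mirroring the patched selection logic."""
  # Default to the largest bucket; a too-long sentence will be truncated to it.
  bucket_id = len(_buckets) - 1
  for i, bucket in enumerate(_buckets):
    if bucket[0] >= len(token_ids):
      bucket_id = i
      break
  else:
    # Runs only when the loop never hit `break`, i.e. no bucket fits.
    logging.warning("Sentence truncated: %s", sentence)
  return bucket_id


print(pick_bucket(list(range(7)), "seven tokens"))   # 1: bucket (10, 15) fits
print(pick_bucket(list(range(60)), "sixty tokens"))  # 3: falls back and logs a warning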