Fix too long sentence (#4962)
* Fix too long sentence * Make previous commit more idiomatic * Fix missing logging import
This commit is contained in:
parent
1511dd4bad
commit
e7066fb9c1
@ -36,6 +36,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
@ -238,8 +239,14 @@ def decode():
|
|||||||
# Get token-ids for the input sentence.
|
# Get token-ids for the input sentence.
|
||||||
token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)
|
token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)
|
||||||
# Which bucket does it belong to?
|
# Which bucket does it belong to?
|
||||||
bucket_id = min([b for b in xrange(len(_buckets))
|
bucket_id = len(_buckets) - 1
|
||||||
if _buckets[b][0] > len(token_ids)])
|
for i, bucket in enumerate(_buckets):
|
||||||
|
if bucket[0] >= len(token_ids):
|
||||||
|
bucket_id = i
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logging.warning("Sentence truncated: %s", sentence)
|
||||||
|
|
||||||
# Get a 1-element batch to feed the sentence to the model.
|
# Get a 1-element batch to feed the sentence to the model.
|
||||||
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
|
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
|
||||||
{bucket_id: [(token_ids, [])]}, bucket_id)
|
{bucket_id: [(token_ids, [])]}, bucket_id)
|
||||||
|
Loading…
Reference in New Issue
Block a user