Merge pull request #40610 from dsandeep0138:maxlen-issue

PiperOrigin-RevId: 317669885
Change-Id: Ie727ac3f6238643dcd6065c073a4e8cbb2757077
This commit is contained in:
TensorFlower Gardener 2020-06-22 09:52:13 -07:00
commit 8498c64d77

View File

@ -124,20 +124,24 @@ def load_data(path='imdb.npz',
x_test = x_test[indices]
labels_test = labels_test[indices]
xs = np.concatenate([x_train, x_test])
labels = np.concatenate([labels_train, labels_test])
if start_char is not None:
xs = [[start_char] + [w + index_from for w in x] for x in xs]
x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
x_test = [[start_char] + [w + index_from for w in x] for x in x_test]
elif index_from:
xs = [[w + index_from for w in x] for x in xs]
x_train = [[w + index_from for w in x] for x in x_train]
x_test = [[w + index_from for w in x] for x in x_test]
if maxlen:
xs, labels = _remove_long_seq(maxlen, xs, labels)
if not xs:
x_train, labels_train = _remove_long_seq(maxlen, x_train, labels_train)
x_test, labels_test = _remove_long_seq(maxlen, x_test, labels_test)
if not x_train or not x_test:
raise ValueError('After filtering for sequences shorter than maxlen=' +
str(maxlen) + ', no sequence was kept. '
'Increase maxlen.')
xs = np.concatenate([x_train, x_test])
labels = np.concatenate([labels_train, labels_test])
if not num_words:
num_words = max(max(x) for x in xs)