Merge pull request #40610 from dsandeep0138:maxlen-issue
PiperOrigin-RevId: 317669885 Change-Id: Ie727ac3f6238643dcd6065c073a4e8cbb2757077
This commit is contained in:
commit
8498c64d77
@@ -124,20 +124,24 @@ def load_data(path='imdb.npz',
|
||||
x_test = x_test[indices]
|
||||
labels_test = labels_test[indices]
|
||||
|
||||
xs = np.concatenate([x_train, x_test])
|
||||
labels = np.concatenate([labels_train, labels_test])
|
||||
|
||||
if start_char is not None:
|
||||
xs = [[start_char] + [w + index_from for w in x] for x in xs]
|
||||
x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
|
||||
x_test = [[start_char] + [w + index_from for w in x] for x in x_test]
|
||||
elif index_from:
|
||||
xs = [[w + index_from for w in x] for x in xs]
|
||||
x_train = [[w + index_from for w in x] for x in x_train]
|
||||
x_test = [[w + index_from for w in x] for x in x_test]
|
||||
|
||||
if maxlen:
|
||||
xs, labels = _remove_long_seq(maxlen, xs, labels)
|
||||
if not xs:
|
||||
x_train, labels_train = _remove_long_seq(maxlen, x_train, labels_train)
|
||||
x_test, labels_test = _remove_long_seq(maxlen, x_test, labels_test)
|
||||
if not x_train or not x_test:
|
||||
raise ValueError('After filtering for sequences shorter than maxlen=' +
|
||||
str(maxlen) + ', no sequence was kept. '
|
||||
'Increase maxlen.')
|
||||
|
||||
xs = np.concatenate([x_train, x_test])
|
||||
labels = np.concatenate([labels_train, labels_test])
|
||||
|
||||
if not num_words:
|
||||
num_words = max(max(x) for x in xs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user