diff --git a/tensorflow/python/keras/datasets/imdb.py b/tensorflow/python/keras/datasets/imdb.py index 37403228edf..e359d691a5d 100644 --- a/tensorflow/python/keras/datasets/imdb.py +++ b/tensorflow/python/keras/datasets/imdb.py @@ -124,20 +124,24 @@ def load_data(path='imdb.npz', x_test = x_test[indices] labels_test = labels_test[indices] - xs = np.concatenate([x_train, x_test]) - labels = np.concatenate([labels_train, labels_test]) - if start_char is not None: - xs = [[start_char] + [w + index_from for w in x] for x in xs] + x_train = [[start_char] + [w + index_from for w in x] for x in x_train] + x_test = [[start_char] + [w + index_from for w in x] for x in x_test] elif index_from: - xs = [[w + index_from for w in x] for x in xs] + x_train = [[w + index_from for w in x] for x in x_train] + x_test = [[w + index_from for w in x] for x in x_test] if maxlen: - xs, labels = _remove_long_seq(maxlen, xs, labels) - if not xs: + x_train, labels_train = _remove_long_seq(maxlen, x_train, labels_train) + x_test, labels_test = _remove_long_seq(maxlen, x_test, labels_test) + if not x_train or not x_test: raise ValueError('After filtering for sequences shorter than maxlen=' + str(maxlen) + ', no sequence was kept. ' 'Increase maxlen.') + + xs = np.concatenate([x_train, x_test]) + labels = np.concatenate([labels_train, labels_test]) + if not num_words: num_words = max(max(x) for x in xs)