Ensure there are test samples for the imdb dataset when maxlen is low

With the current imdb.load_data(), the following results are seen
for different values of maxlen.

     load_data                   (len(x_train), len(x_test))
------------------------------------------------------------
imdb.load_data(maxlen=50)    -->    (1035, 0)
imdb.load_data(maxlen=100)   -->    (5736, 0)
imdb.load_data(maxlen=200)   -->    (25000, 3913)
imdb.load_data()             -->    (25000, 25000)

Analysis: When maxlen is low, the number of test samples can drop to 0.
This is because the train and test data are first concatenated, then the
samples with length >= maxlen are removed, and the first 25,000 of the
remaining samples are treated as training data, leaving whatever is left
(possibly nothing) as test data.

Fix: Filter each split first to remove the samples with length >= maxlen,
and only then concatenate for further processing. Note that the totals are
unchanged (e.g. 477 + 558 = 1035 for maxlen=50); only the train/test split
is now correct. The following are the results after the fix.
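The broken ordering can be reproduced with a small sketch (toy sequences and
a simplified filter standing in for `_remove_long_seq`; the real loader uses
25,000 samples per split):

```python
def remove_long_seq(maxlen, seqs, labels):
    """Keep only sequences strictly shorter than maxlen."""
    new_seqs, new_labels = [], []
    for seq, lab in zip(seqs, labels):
        if len(seq) < maxlen:
            new_seqs.append(seq)
            new_labels.append(lab)
    return new_seqs, new_labels

# 5 "train" and 5 "test" toy sequences; only 3 of the 10 are shorter than 3.
x_train = [[1, 2], [1, 2, 3], [1], [1, 2, 3, 4], [1, 2, 3]]
x_test = [[5, 6], [5, 6, 7], [5, 6, 7, 8], [5, 6, 7], [5, 6, 7, 8, 9]]
y_train, y_test = [0] * 5, [1] * 5

# Old behaviour: concatenate, filter, then split at the ORIGINAL train size.
xs, labels = remove_long_seq(3, x_train + x_test, y_train + y_test)
idx = len(x_train)                    # split point ignores how many were dropped
old_train, old_test = xs[:idx], xs[idx:]
print(len(old_train), len(old_test))  # 3 0 -> the test split is empty

# Fixed behaviour: filter each split first, then concatenate.
new_train, _ = remove_long_seq(3, x_train, y_train)
new_test, _ = remove_long_seq(3, x_test, y_test)
print(len(new_train), len(new_test))  # 2 1 -> both splits keep their samples
```

With only 3 of 10 samples surviving, the old split point of 5 swallows them
all into the training set, mirroring the (1035, 0) result above.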

     fixed load_data              (len(x_train), len(x_test))
------------------------------------------------------------
imdb.load_data(maxlen=50)    -->    (477, 558)
imdb.load_data(maxlen=100)   -->    (2773, 2963)
imdb.load_data(maxlen=200)   -->    (14244, 14669)
imdb.load_data()             -->    (25000, 25000)
Devi Sandeep Endluri 2020-06-19 06:23:39 -05:00
parent e972c55726
commit 7b1a726ec3

@@ -124,20 +124,24 @@ def load_data(path='imdb.npz',
     x_test = x_test[indices]
     labels_test = labels_test[indices]
 
-    xs = np.concatenate([x_train, x_test])
-    labels = np.concatenate([labels_train, labels_test])
-
     if start_char is not None:
-        xs = [[start_char] + [w + index_from for w in x] for x in xs]
+        x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
+        x_test = [[start_char] + [w + index_from for w in x] for x in x_test]
     elif index_from:
-        xs = [[w + index_from for w in x] for x in xs]
+        x_train = [[w + index_from for w in x] for x in x_train]
+        x_test = [[w + index_from for w in x] for x in x_test]
 
     if maxlen:
-        xs, labels = _remove_long_seq(maxlen, xs, labels)
-        if not xs:
+        x_train, labels_train = _remove_long_seq(maxlen, x_train, labels_train)
+        x_test, labels_test = _remove_long_seq(maxlen, x_test, labels_test)
+        if not x_train or not x_test:
             raise ValueError('After filtering for sequences shorter than maxlen=' +
                              str(maxlen) + ', no sequence was kept. '
                              'Increase maxlen.')
+
+    xs = np.concatenate([x_train, x_test])
+    labels = np.concatenate([labels_train, labels_test])
 
     if not num_words:
         num_words = max(max(x) for x in xs)