Ensure there are test samples for the imdb dataset when maxlen is low
With the current imdb.load_data(), the following results are seen for
different values of maxlen.

load_data                         (len(x_train), len(x_test))
------------------------------------------------------------
imdb.load_data(maxlen=50)    -->  (1035, 0)
imdb.load_data(maxlen=100)   -->  (5736, 0)
imdb.load_data(maxlen=200)   -->  (25000, 3913)
imdb.load_data()             -->  (25000, 25000)

Analysis:
When maxlen is low, the number of test samples can be 0. This happens
because the train and test data are concatenated, the samples with
length > maxlen are then removed, and the first 25,000 of whatever
remains are treated as training data.

Fix:
Filter each split first to remove the samples with length > maxlen,
and only then concatenate for further processing. The following are
the results after the fix; a toy reproduction of the old and new
behaviour follows below.

fixed load_data                   (len(x_train), len(x_test))
------------------------------------------------------------
imdb.load_data(maxlen=50)    -->  (477, 558)
imdb.load_data(maxlen=100)   -->  (2773, 2963)
imdb.load_data(maxlen=200)   -->  (14244, 14669)
imdb.load_data()             -->  (25000, 25000)
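To make the failure mode concrete, here is a minimal sketch with toy
data. The sequence lengths, labels, and the local _remove_long_seq
helper are illustrative stand-ins; the real helper,
keras.preprocessing.sequence._remove_long_seq, keeps the pairs whose
sequence is shorter than maxlen.

    def _remove_long_seq(maxlen, seqs, labels):
        # Keep only (sequence, label) pairs with len(sequence) < maxlen.
        kept = [(x, y) for x, y in zip(seqs, labels) if len(x) < maxlen]
        return [x for x, _ in kept], [y for _, y in kept]

    x_train = [[1] * n for n in (2, 3, 9, 9, 9)]   # toy "train" split
    x_test = [[1] * n for n in (2, 9, 9, 9, 9)]    # toy "test" split
    labels_train, labels_test = [0] * 5, [1] * 5

    # Old behaviour: concatenate, filter, then split at the *original*
    # train size, even though filtering shrank the combined list.
    idx = len(x_train)
    xs, labels = _remove_long_seq(4, x_train + x_test,
                                  labels_train + labels_test)
    print(len(xs[:idx]), len(xs[idx:]))            # 3 0 -> empty test split

    # Fixed behaviour: filter each split first, then concatenate.
    x_train, labels_train = _remove_long_seq(4, x_train, labels_train)
    x_test, labels_test = _remove_long_seq(4, x_test, labels_test)
    print(len(x_train), len(x_test))               # 2 1 -> both non-empty

Because the surviving short sequences are not spread evenly across the
combined list, the post-filter split point no longer matches the
train/test boundary, which is exactly what the table above shows at
maxlen=50 and maxlen=100.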
parent e972c55726
commit 7b1a726ec3
@@ -124,20 +124,24 @@ def load_data(path='imdb.npz',
     x_test = x_test[indices]
     labels_test = labels_test[indices]
 
-    xs = np.concatenate([x_train, x_test])
-    labels = np.concatenate([labels_train, labels_test])
-
     if start_char is not None:
-        xs = [[start_char] + [w + index_from for w in x] for x in xs]
+        x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
+        x_test = [[start_char] + [w + index_from for w in x] for x in x_test]
     elif index_from:
-        xs = [[w + index_from for w in x] for x in xs]
+        x_train = [[w + index_from for w in x] for x in x_train]
+        x_test = [[w + index_from for w in x] for x in x_test]
 
     if maxlen:
-        xs, labels = _remove_long_seq(maxlen, xs, labels)
-        if not xs:
+        x_train, labels_train = _remove_long_seq(maxlen, x_train, labels_train)
+        x_test, labels_test = _remove_long_seq(maxlen, x_test, labels_test)
+        if not x_train or not x_test:
             raise ValueError('After filtering for sequences shorter than maxlen=' +
                              str(maxlen) + ', no sequence was kept. '
                              'Increase maxlen.')
 
+    xs = np.concatenate([x_train, x_test])
+    labels = np.concatenate([labels_train, labels_test])
+
     if not num_words:
         num_words = max(max(x) for x in xs)
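A quick way to sanity-check the patched loader (this downloads
imdb.npz on first use; the expected counts are the ones reported in
the commit message above):

    from keras.datasets import imdb

    # With the fix, a low maxlen still leaves samples in both splits.
    (x_train, y_train), (x_test, y_test) = imdb.load_data(maxlen=50)
    print(len(x_train), len(x_test))  # commit message reports (477, 558)
    assert len(x_train) > 0 and len(x_test) > 0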