Prevent Keras dataset loading from affecting the global RNG
PiperOrigin-RevId: 316533425 Change-Id: I6099847f9a7ead24786fb2fecd5ba488f53456e6
This commit is contained in:
parent
df4ea0c1a5
commit
52736a6adc
@ -113,14 +113,14 @@ def load_data(path='imdb.npz',
|
||||
x_train, labels_train = f['x_train'], f['y_train']
|
||||
x_test, labels_test = f['x_test'], f['y_test']
|
||||
|
||||
np.random.seed(seed)
|
||||
rng = np.random.RandomState(seed)
|
||||
indices = np.arange(len(x_train))
|
||||
np.random.shuffle(indices)
|
||||
rng.shuffle(indices)
|
||||
x_train = x_train[indices]
|
||||
labels_train = labels_train[indices]
|
||||
|
||||
indices = np.arange(len(x_test))
|
||||
np.random.shuffle(indices)
|
||||
rng.shuffle(indices)
|
||||
x_test = x_test[indices]
|
||||
labels_test = labels_test[indices]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user