Prevent Keras dataset loading from affecting the global RNG seed.
PiperOrigin-RevId: 316527944 Change-Id: I13fc997ffafc02f25b94e45265c7aa97b6efc6c4
This commit is contained in:
parent
540852285d
commit
18e0e6450b
tensorflow/python/keras/datasets
@ -67,9 +67,9 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
|
||||
x = f['x']
|
||||
y = f['y']
|
||||
|
||||
np.random.seed(seed)
|
||||
rng = np.random.RandomState(seed)
|
||||
indices = np.arange(len(x))
|
||||
np.random.shuffle(indices)
|
||||
rng.shuffle(indices)
|
||||
x = x[indices]
|
||||
y = y[indices]
|
||||
|
||||
|
@ -119,9 +119,9 @@ def load_data(path='reuters.npz',
|
||||
with np.load(path, allow_pickle=True) as f:
|
||||
xs, labels = f['x'], f['y']
|
||||
|
||||
np.random.seed(seed)
|
||||
rng = np.random.RandomState(seed)
|
||||
indices = np.arange(len(xs))
|
||||
np.random.shuffle(indices)
|
||||
rng.shuffle(indices)
|
||||
xs = xs[indices]
|
||||
labels = labels[indices]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user