Prevent Keras dataset loading from affecting the global RNG seed.
PiperOrigin-RevId: 316527944 Change-Id: I13fc997ffafc02f25b94e45265c7aa97b6efc6c4
This commit is contained in:
parent
540852285d
commit
18e0e6450b
@ -67,9 +67,9 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
|
|||||||
x = f['x']
|
x = f['x']
|
||||||
y = f['y']
|
y = f['y']
|
||||||
|
|
||||||
np.random.seed(seed)
|
rng = np.random.RandomState(seed)
|
||||||
indices = np.arange(len(x))
|
indices = np.arange(len(x))
|
||||||
np.random.shuffle(indices)
|
rng.shuffle(indices)
|
||||||
x = x[indices]
|
x = x[indices]
|
||||||
y = y[indices]
|
y = y[indices]
|
||||||
|
|
||||||
|
@ -119,9 +119,9 @@ def load_data(path='reuters.npz',
|
|||||||
with np.load(path, allow_pickle=True) as f:
|
with np.load(path, allow_pickle=True) as f:
|
||||||
xs, labels = f['x'], f['y']
|
xs, labels = f['x'], f['y']
|
||||||
|
|
||||||
np.random.seed(seed)
|
rng = np.random.RandomState(seed)
|
||||||
indices = np.arange(len(xs))
|
indices = np.arange(len(xs))
|
||||||
np.random.shuffle(indices)
|
rng.shuffle(indices)
|
||||||
xs = xs[indices]
|
xs = xs[indices]
|
||||||
labels = labels[indices]
|
labels = labels[indices]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user