Part 3/4 of the update of tf.keras to the 2.2.4 API.
PiperOrigin-RevId: 216639755
This commit is contained in:
parent
2b8f59243e
commit
96a633367e
tensorflow
@ -718,6 +718,19 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "training_generator_test",
|
||||
size = "enormous",
|
||||
srcs = ["engine/training_generator_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["notsan"],
|
||||
deps = [
|
||||
":keras",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "feature_columns_integration_test",
|
||||
size = "small",
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.keras import callbacks as cbks
|
||||
from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
|
||||
from tensorflow.python.keras.utils.data_utils import iter_sequence_infinite
|
||||
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer
|
||||
from tensorflow.python.keras.utils.data_utils import Sequence
|
||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||
@ -45,7 +46,6 @@ def fit_generator(model,
|
||||
shuffle=True,
|
||||
initial_epoch=0):
|
||||
"""See docstring for `Model.fit_generator`."""
|
||||
wait_time = 0.01 # in seconds
|
||||
epoch = initial_epoch
|
||||
|
||||
do_validation = bool(validation_data)
|
||||
@ -124,13 +124,12 @@ def fit_generator(model,
|
||||
else:
|
||||
enqueuer = GeneratorEnqueuer(
|
||||
generator,
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
wait_time=wait_time)
|
||||
use_multiprocessing=use_multiprocessing)
|
||||
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
|
||||
output_generator = enqueuer.get()
|
||||
else:
|
||||
if is_sequence:
|
||||
output_generator = iter(generator)
|
||||
output_generator = iter_sequence_infinite(generator)
|
||||
else:
|
||||
output_generator = generator
|
||||
|
||||
@ -251,7 +250,6 @@ def evaluate_generator(model,
|
||||
stateful_metric_indices = []
|
||||
|
||||
steps_done = 0
|
||||
wait_time = 0.01
|
||||
all_outs = []
|
||||
batch_sizes = []
|
||||
is_sequence = isinstance(generator, Sequence)
|
||||
@ -279,13 +277,12 @@ def evaluate_generator(model,
|
||||
else:
|
||||
enqueuer = GeneratorEnqueuer(
|
||||
generator,
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
wait_time=wait_time)
|
||||
use_multiprocessing=use_multiprocessing)
|
||||
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
|
||||
output_generator = enqueuer.get()
|
||||
else:
|
||||
if is_sequence:
|
||||
output_generator = iter(generator)
|
||||
output_generator = iter_sequence_infinite(generator)
|
||||
else:
|
||||
output_generator = generator
|
||||
|
||||
@ -354,7 +351,6 @@ def predict_generator(model,
|
||||
model._make_test_function()
|
||||
|
||||
steps_done = 0
|
||||
wait_time = 0.01
|
||||
all_outs = []
|
||||
is_sequence = isinstance(generator, Sequence)
|
||||
if not is_sequence and use_multiprocessing and workers > 1:
|
||||
@ -381,13 +377,12 @@ def predict_generator(model,
|
||||
else:
|
||||
enqueuer = GeneratorEnqueuer(
|
||||
generator,
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
wait_time=wait_time)
|
||||
use_multiprocessing=use_multiprocessing)
|
||||
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
|
||||
output_generator = enqueuer.get()
|
||||
else:
|
||||
if is_sequence:
|
||||
output_generator = iter(generator)
|
||||
output_generator = iter_sequence_infinite(generator)
|
||||
else:
|
||||
output_generator = generator
|
||||
|
||||
|
307
tensorflow/python/keras/engine/training_generator_test.py
Normal file
307
tensorflow/python/keras/engine/training_generator_test.py
Normal file
@ -0,0 +1,307 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for training routines."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras import metrics as metrics_module
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||
|
||||
|
||||
class TestGeneratorMethods(test.TestCase):
|
||||
|
||||
@unittest.skipIf(
|
||||
os.name == 'nt',
|
||||
'use_multiprocessing=True does not work on windows properly.')
|
||||
def test_generator_methods(self):
|
||||
arr_data = np.random.random((50, 2))
|
||||
arr_labels = np.random.random((50,))
|
||||
|
||||
def custom_generator():
|
||||
batch_size = 10
|
||||
num_samples = 50
|
||||
while True:
|
||||
batch_index = np.random.randint(0, num_samples - batch_size)
|
||||
start = batch_index
|
||||
end = start + batch_size
|
||||
x = arr_data[start: end]
|
||||
y = arr_labels[start: end]
|
||||
yield x, y
|
||||
|
||||
with self.cached_session():
|
||||
x = keras.Input((2,))
|
||||
y = keras.layers.Dense(1)(x)
|
||||
fn_model = keras.models.Model(x, y)
|
||||
fn_model.compile(
|
||||
loss='mse',
|
||||
optimizer='sgd',
|
||||
metrics=['mae', metrics_module.CategoricalAccuracy()])
|
||||
|
||||
seq_model = keras.models.Sequential()
|
||||
seq_model.add(keras.layers.Dense(1, input_shape=(2,)))
|
||||
seq_model.compile(loss='mse', optimizer='sgd')
|
||||
|
||||
for model in [fn_model, seq_model]:
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
workers=4,
|
||||
use_multiprocessing=True)
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=10)
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=1,
|
||||
workers=0)
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
workers=2,
|
||||
use_multiprocessing=True)
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
workers=0)
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
workers=2,
|
||||
verbose=1,
|
||||
use_multiprocessing=True)
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False,
|
||||
workers=0)
|
||||
|
||||
def test_generator_methods_with_sample_weights(self):
|
||||
arr_data = np.random.random((50, 2))
|
||||
arr_labels = np.random.random((50,))
|
||||
arr_sample_weights = np.random.random((50,))
|
||||
|
||||
def custom_generator():
|
||||
batch_size = 10
|
||||
num_samples = 50
|
||||
while True:
|
||||
batch_index = np.random.randint(0, num_samples - batch_size)
|
||||
start = batch_index
|
||||
end = start + batch_size
|
||||
x = arr_data[start: end]
|
||||
y = arr_labels[start: end]
|
||||
w = arr_sample_weights[start: end]
|
||||
yield x, y, w
|
||||
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(1, input_shape=(2,)))
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer='sgd',
|
||||
metrics=['mae', metrics_module.CategoricalAccuracy()])
|
||||
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=10)
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
|
||||
def test_generator_methods_invalid_use_case(self):
|
||||
|
||||
def custom_generator():
|
||||
while 1:
|
||||
yield 0
|
||||
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(1, input_shape=(2,)))
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
with self.assertRaises(ValueError):
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=10)
|
||||
with self.assertRaises(AttributeError):
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
with self.assertRaises(ValueError):
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
|
||||
def test_training_with_sequences(self):
|
||||
|
||||
class DummySequence(keras.utils.Sequence):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return np.zeros([10, 2]), np.ones([10])
|
||||
|
||||
def __len__(self):
|
||||
return 10
|
||||
|
||||
arr_data = np.random.random((50, 2))
|
||||
arr_labels = np.random.random((50,))
|
||||
arr_sample_weights = np.random.random((50,))
|
||||
|
||||
def custom_generator():
|
||||
batch_size = 10
|
||||
num_samples = 50
|
||||
while True:
|
||||
batch_index = np.random.randint(0, num_samples - batch_size)
|
||||
start = batch_index
|
||||
end = start + batch_size
|
||||
x = arr_data[start: end]
|
||||
y = arr_labels[start: end]
|
||||
w = arr_sample_weights[start: end]
|
||||
yield x, y, w
|
||||
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(1, input_shape=(2,)))
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
|
||||
model.fit_generator(DummySequence(),
|
||||
steps_per_epoch=10,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=1,
|
||||
max_queue_size=10,
|
||||
workers=0,
|
||||
use_multiprocessing=True)
|
||||
model.fit_generator(DummySequence(),
|
||||
steps_per_epoch=10,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=1,
|
||||
max_queue_size=10,
|
||||
workers=0,
|
||||
use_multiprocessing=False)
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_generator_input_to_fit_eval_predict(self):
|
||||
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
|
||||
|
||||
def custom_generator():
|
||||
while True:
|
||||
yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
|
||||
|
||||
inputs = keras.layers.Input(shape=(10,))
|
||||
x = keras.layers.Dense(10, activation='relu')(inputs)
|
||||
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
|
||||
model = keras.Model(inputs, outputs)
|
||||
|
||||
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
|
||||
model.fit(
|
||||
custom_generator(),
|
||||
steps_per_epoch=2,
|
||||
validation_data=val_data,
|
||||
epochs=2)
|
||||
model.evaluate(custom_generator(), steps=2)
|
||||
model.predict(custom_generator(), steps=2)
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_sequence_input_to_fit_eval_predict(self):
|
||||
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
|
||||
|
||||
class CustomSequence(keras.utils.Sequence):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
|
||||
|
||||
def __len__(self):
|
||||
return 2
|
||||
|
||||
inputs = keras.layers.Input(shape=(10,))
|
||||
x = keras.layers.Dense(10, activation='relu')(inputs)
|
||||
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
|
||||
model = keras.Model(inputs, outputs)
|
||||
|
||||
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
|
||||
model.fit(CustomSequence(), validation_data=val_data, epochs=2)
|
||||
model.evaluate(CustomSequence())
|
||||
model.predict(CustomSequence())
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'):
|
||||
model.fit(CustomSequence(), y=np.ones([10, 1]))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'`sample_weight` argument is not supported'):
|
||||
model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -19,8 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1102,279 +1100,6 @@ class TestDynamicTrainability(test.TestCase):
|
||||
self.assertListEqual(outer_model.trainable_weights, [])
|
||||
|
||||
|
||||
class TestGeneratorMethods(test.TestCase):
|
||||
|
||||
@unittest.skipIf(
|
||||
os.name == 'nt',
|
||||
'use_multiprocessing=True does not work on windows properly.')
|
||||
def test_generator_methods(self):
|
||||
arr_data = np.random.random((50, 2))
|
||||
arr_labels = np.random.random((50,))
|
||||
|
||||
def custom_generator():
|
||||
batch_size = 10
|
||||
num_samples = 50
|
||||
while True:
|
||||
batch_index = np.random.randint(0, num_samples - batch_size)
|
||||
start = batch_index
|
||||
end = start + batch_size
|
||||
x = arr_data[start: end]
|
||||
y = arr_labels[start: end]
|
||||
yield x, y
|
||||
|
||||
with self.cached_session():
|
||||
x = keras.Input((2,))
|
||||
y = keras.layers.Dense(1)(x)
|
||||
fn_model = keras.models.Model(x, y)
|
||||
fn_model.compile(
|
||||
loss='mse',
|
||||
optimizer='sgd',
|
||||
metrics=['mae', metrics_module.CategoricalAccuracy()])
|
||||
|
||||
seq_model = keras.models.Sequential()
|
||||
seq_model.add(keras.layers.Dense(1, input_shape=(2,)))
|
||||
seq_model.compile(loss='mse', optimizer='sgd')
|
||||
|
||||
for model in [fn_model, seq_model]:
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
workers=4,
|
||||
use_multiprocessing=True)
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=10)
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=1,
|
||||
workers=0)
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
workers=2,
|
||||
use_multiprocessing=True)
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
workers=0)
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
workers=2,
|
||||
verbose=1,
|
||||
use_multiprocessing=True)
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False,
|
||||
workers=0)
|
||||
|
||||
def test_generator_methods_with_sample_weights(self):
|
||||
arr_data = np.random.random((50, 2))
|
||||
arr_labels = np.random.random((50,))
|
||||
arr_sample_weights = np.random.random((50,))
|
||||
|
||||
def custom_generator():
|
||||
batch_size = 10
|
||||
num_samples = 50
|
||||
while True:
|
||||
batch_index = np.random.randint(0, num_samples - batch_size)
|
||||
start = batch_index
|
||||
end = start + batch_size
|
||||
x = arr_data[start: end]
|
||||
y = arr_labels[start: end]
|
||||
w = arr_sample_weights[start: end]
|
||||
yield x, y, w
|
||||
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(1, input_shape=(2,)))
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer='sgd',
|
||||
metrics=['mae', metrics_module.CategoricalAccuracy()])
|
||||
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=10)
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
|
||||
def test_generator_methods_invalid_use_case(self):
|
||||
|
||||
def custom_generator():
|
||||
while 1:
|
||||
yield 0
|
||||
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(1, input_shape=(2,)))
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
with self.assertRaises(ValueError):
|
||||
model.fit_generator(custom_generator(),
|
||||
steps_per_epoch=5,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=10)
|
||||
with self.assertRaises(AttributeError):
|
||||
model.predict_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
with self.assertRaises(ValueError):
|
||||
model.evaluate_generator(custom_generator(),
|
||||
steps=5,
|
||||
max_queue_size=10,
|
||||
use_multiprocessing=False)
|
||||
|
||||
def test_training_with_sequences(self):
|
||||
|
||||
class DummySequence(keras.utils.Sequence):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return np.zeros([10, 2]), np.ones([10])
|
||||
|
||||
def __len__(self):
|
||||
return 10
|
||||
|
||||
arr_data = np.random.random((50, 2))
|
||||
arr_labels = np.random.random((50,))
|
||||
arr_sample_weights = np.random.random((50,))
|
||||
|
||||
def custom_generator():
|
||||
batch_size = 10
|
||||
num_samples = 50
|
||||
while True:
|
||||
batch_index = np.random.randint(0, num_samples - batch_size)
|
||||
start = batch_index
|
||||
end = start + batch_size
|
||||
x = arr_data[start: end]
|
||||
y = arr_labels[start: end]
|
||||
w = arr_sample_weights[start: end]
|
||||
yield x, y, w
|
||||
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(1, input_shape=(2,)))
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
|
||||
model.fit_generator(DummySequence(),
|
||||
steps_per_epoch=10,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=1,
|
||||
max_queue_size=10,
|
||||
workers=0,
|
||||
use_multiprocessing=True)
|
||||
model.fit_generator(DummySequence(),
|
||||
steps_per_epoch=10,
|
||||
validation_data=custom_generator(),
|
||||
validation_steps=1,
|
||||
max_queue_size=10,
|
||||
workers=0,
|
||||
use_multiprocessing=False)
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_generator_input_to_fit_eval_predict(self):
|
||||
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
|
||||
|
||||
def custom_generator():
|
||||
while True:
|
||||
yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
|
||||
|
||||
inputs = keras.layers.Input(shape=(10,))
|
||||
x = keras.layers.Dense(10, activation='relu')(inputs)
|
||||
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
|
||||
model = keras.Model(inputs, outputs)
|
||||
|
||||
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
|
||||
model.fit(
|
||||
custom_generator(),
|
||||
steps_per_epoch=2,
|
||||
validation_data=val_data,
|
||||
epochs=2)
|
||||
model.evaluate(custom_generator(), steps=2)
|
||||
model.predict(custom_generator(), steps=2)
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_sequence_input_to_fit_eval_predict(self):
|
||||
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
|
||||
|
||||
class CustomSequence(keras.utils.Sequence):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
|
||||
|
||||
def __len__(self):
|
||||
return 2
|
||||
|
||||
inputs = keras.layers.Input(shape=(10,))
|
||||
x = keras.layers.Dense(10, activation='relu')(inputs)
|
||||
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
|
||||
model = keras.Model(inputs, outputs)
|
||||
|
||||
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
|
||||
model.fit(CustomSequence(), validation_data=val_data, epochs=2)
|
||||
model.evaluate(CustomSequence())
|
||||
model.predict(CustomSequence())
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'):
|
||||
model.fit(CustomSequence(), y=np.ones([10, 1]))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'`sample_weight` argument is not supported'):
|
||||
model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))
|
||||
|
||||
|
||||
class TestTrainingUtils(test.TestCase):
|
||||
|
||||
def test_check_array_lengths(self):
|
||||
|
@ -30,7 +30,6 @@ import sys
|
||||
import tarfile
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import zipfile
|
||||
|
||||
import numpy as np
|
||||
@ -117,16 +116,16 @@ def _extract_archive(file_path, path='.', archive_format='auto'):
|
||||
"""
|
||||
if archive_format is None:
|
||||
return False
|
||||
if archive_format is 'auto':
|
||||
if archive_format == 'auto':
|
||||
archive_format = ['tar', 'zip']
|
||||
if isinstance(archive_format, six.string_types):
|
||||
archive_format = [archive_format]
|
||||
|
||||
for archive_type in archive_format:
|
||||
if archive_type is 'tar':
|
||||
if archive_type == 'tar':
|
||||
open_fn = tarfile.open
|
||||
is_match_fn = tarfile.is_tarfile
|
||||
if archive_type is 'zip':
|
||||
if archive_type == 'zip':
|
||||
open_fn = zipfile.ZipFile
|
||||
is_match_fn = zipfile.is_zipfile
|
||||
|
||||
@ -237,7 +236,7 @@ def get_file(fname,
|
||||
|
||||
def dl_progress(count, block_size, total_size):
|
||||
if ProgressTracker.progbar is None:
|
||||
if total_size is -1:
|
||||
if total_size == -1:
|
||||
total_size = None
|
||||
ProgressTracker.progbar = Progbar(total_size)
|
||||
else:
|
||||
@ -288,7 +287,7 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
|
||||
Returns:
|
||||
The file hash
|
||||
"""
|
||||
if (algorithm is 'sha256') or (algorithm is 'auto' and len(hash) is 64):
|
||||
if (algorithm == 'sha256') or (algorithm == 'auto' and len(hash) == 64):
|
||||
hasher = hashlib.sha256()
|
||||
else:
|
||||
hasher = hashlib.md5()
|
||||
@ -314,8 +313,7 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
|
||||
Returns:
|
||||
Whether the file is valid
|
||||
"""
|
||||
if ((algorithm is 'sha256') or
|
||||
(algorithm is 'auto' and len(file_hash) is 64)):
|
||||
if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
|
||||
hasher = 'sha256'
|
||||
else:
|
||||
hasher = 'md5'
|
||||
@ -400,14 +398,23 @@ class Sequence(object):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
"""Creates an infinite generator that iterate over the Sequence.
|
||||
"""Create a generator that iterate over the Sequence."""
|
||||
for item in (self[i] for i in range(len(self))):
|
||||
yield item
|
||||
|
||||
Yields:
|
||||
Sequence items.
|
||||
"""
|
||||
while True:
|
||||
for item in (self[i] for i in range(len(self))):
|
||||
yield item
|
||||
|
||||
def iter_sequence_infinite(seq):
|
||||
"""Iterates indefinitely over a Sequence.
|
||||
|
||||
Arguments:
|
||||
seq: Sequence instance.
|
||||
|
||||
Yields:
|
||||
Batches of data from the Sequence.
|
||||
"""
|
||||
while True:
|
||||
for item in seq:
|
||||
yield item
|
||||
|
||||
|
||||
# Global variables to be shared across processes
|
||||
@ -445,7 +452,7 @@ class SequenceEnqueuer(object):
|
||||
The task of an Enqueuer is to use parallelism to speed up preprocessing.
|
||||
This is done with processes or threads.
|
||||
|
||||
Examples:
|
||||
Example:
|
||||
|
||||
```python
|
||||
enqueuer = SequenceEnqueuer(...)
|
||||
@ -458,61 +465,10 @@ class SequenceEnqueuer(object):
|
||||
```
|
||||
|
||||
The `enqueuer.get()` should be an infinite stream of datas.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def is_running(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def start(self, workers=1, max_queue_size=10):
|
||||
"""Starts the handler's workers.
|
||||
|
||||
Arguments:
|
||||
workers: number of worker threads
|
||||
max_queue_size: queue size
|
||||
(when full, threads could block on `put()`).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def stop(self, timeout=None):
|
||||
"""Stop running threads and wait for them to exit, if necessary.
|
||||
|
||||
Should be called by the same thread which called start().
|
||||
|
||||
Arguments:
|
||||
timeout: maximum time to wait on thread.join()
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self):
|
||||
"""Creates a generator to extract data from the queue.
|
||||
|
||||
Skip the data if it is `None`.
|
||||
|
||||
Returns:
|
||||
Generator yielding tuples `(inputs, targets)`
|
||||
or `(inputs, targets, sample_weights)`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@tf_export('keras.utils.OrderedEnqueuer')
|
||||
class OrderedEnqueuer(SequenceEnqueuer):
|
||||
"""Builds a Enqueuer from a Sequence.
|
||||
|
||||
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
|
||||
|
||||
Arguments:
|
||||
sequence: A `keras.utils.data_utils.Sequence` object.
|
||||
use_multiprocessing: use multiprocessing if True, otherwise threading
|
||||
shuffle: whether to shuffle the data at the beginning of each epoch
|
||||
"""
|
||||
|
||||
def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
|
||||
def __init__(self, sequence,
|
||||
use_multiprocessing=False):
|
||||
self.sequence = sequence
|
||||
self.use_multiprocessing = use_multiprocessing
|
||||
|
||||
@ -535,7 +491,6 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
||||
self.uid = _SEQUENCE_COUNTER.value
|
||||
_SEQUENCE_COUNTER.value += 1
|
||||
|
||||
self.shuffle = shuffle
|
||||
self.workers = 0
|
||||
self.executor_fn = None
|
||||
self.queue = None
|
||||
@ -546,16 +501,15 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
||||
return self.stop_signal is not None and not self.stop_signal.is_set()
|
||||
|
||||
def start(self, workers=1, max_queue_size=10):
|
||||
"""Start the handler's workers.
|
||||
"""Starts the handler's workers.
|
||||
|
||||
Arguments:
|
||||
workers: number of worker threads
|
||||
workers: Number of workers.
|
||||
max_queue_size: queue size
|
||||
(when full, workers could block on `put()`)
|
||||
"""
|
||||
if self.use_multiprocessing:
|
||||
self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda
|
||||
workers, initializer=init_pool, initargs=(seqs,))
|
||||
self.executor_fn = self._get_executor_init(workers)
|
||||
else:
|
||||
# We do not need the init since it's threads.
|
||||
self.executor_fn = lambda _: ThreadPool(workers)
|
||||
@ -566,6 +520,87 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
||||
self.run_thread.daemon = True
|
||||
self.run_thread.start()
|
||||
|
||||
def _send_sequence(self):
|
||||
"""Sends current Iterable to all workers."""
|
||||
# For new processes that may spawn
|
||||
_SHARED_SEQUENCES[self.uid] = self.sequence
|
||||
|
||||
def stop(self, timeout=None):
|
||||
"""Stops running threads and wait for them to exit, if necessary.
|
||||
|
||||
Should be called by the same thread which called `start()`.
|
||||
|
||||
Arguments:
|
||||
timeout: maximum time to wait on `thread.join()`
|
||||
"""
|
||||
self.stop_signal.set()
|
||||
with self.queue.mutex:
|
||||
self.queue.queue.clear()
|
||||
self.queue.unfinished_tasks = 0
|
||||
self.queue.not_full.notify()
|
||||
self.run_thread.join(timeout)
|
||||
_SHARED_SEQUENCES[self.uid] = None
|
||||
|
||||
@abstractmethod
|
||||
def _run(self):
|
||||
"""Submits request to the executor and queue the `Future` objects."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_executor_init(self, workers):
|
||||
"""Gets the Pool initializer for multiprocessing.
|
||||
|
||||
Arguments:
|
||||
workers: Number of workers.
|
||||
|
||||
Returns:
|
||||
Function, a Function to initialize the pool
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self):
|
||||
"""Creates a generator to extract data from the queue.
|
||||
|
||||
Skip the data if it is `None`.
|
||||
# Returns
|
||||
Generator yielding tuples `(inputs, targets)`
|
||||
or `(inputs, targets, sample_weights)`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@tf_export('keras.utils.OrderedEnqueuer')
|
||||
class OrderedEnqueuer(SequenceEnqueuer):
|
||||
"""Builds a Enqueuer from a Sequence.
|
||||
|
||||
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
|
||||
|
||||
Arguments:
|
||||
sequence: A `tf.keras.utils.data_utils.Sequence` object.
|
||||
use_multiprocessing: use multiprocessing if True, otherwise threading
|
||||
shuffle: whether to shuffle the data at the beginning of each epoch
|
||||
"""
|
||||
|
||||
def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
|
||||
super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
|
||||
self.shuffle = shuffle
|
||||
|
||||
def _get_executor_init(self, workers):
|
||||
"""Gets the Pool initializer for multiprocessing.
|
||||
|
||||
Arguments:
|
||||
workers: Number of workers.
|
||||
|
||||
Returns:
|
||||
Function, a Function to initialize the pool
|
||||
"""
|
||||
def pool_fn(seqs):
|
||||
return multiprocessing.Pool(workers,
|
||||
initializer=init_pool_generator,
|
||||
initargs=(seqs, self.random_seed))
|
||||
return pool_fn
|
||||
|
||||
def _wait_queue(self):
|
||||
"""Wait for the queue to be empty."""
|
||||
while True:
|
||||
@ -615,30 +650,34 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
||||
self.queue.task_done()
|
||||
if inputs is not None:
|
||||
yield inputs
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.stop()
|
||||
six.raise_from(StopIteration(e), e)
|
||||
six.reraise(*sys.exc_info())
|
||||
|
||||
def _send_sequence(self):
|
||||
"""Send current Sequence to all workers."""
|
||||
# For new processes that may spawn
|
||||
_SHARED_SEQUENCES[self.uid] = self.sequence
|
||||
|
||||
def stop(self, timeout=None):
|
||||
"""Stops running threads and wait for them to exit, if necessary.
|
||||
def init_pool_generator(gens, random_seed=None):
|
||||
global _SHARED_SEQUENCES
|
||||
_SHARED_SEQUENCES = gens
|
||||
|
||||
Should be called by the same thread which called `start()`.
|
||||
if random_seed is not None:
|
||||
ident = multiprocessing.current_process().ident
|
||||
np.random.seed(random_seed + ident)
|
||||
|
||||
Arguments:
|
||||
timeout: maximum time to wait on `thread.join()`
|
||||
"""
|
||||
self.stop_signal.set()
|
||||
with self.queue.mutex:
|
||||
self.queue.queue.clear()
|
||||
self.queue.unfinished_tasks = 0
|
||||
self.queue.not_full.notify()
|
||||
self.run_thread.join(timeout)
|
||||
_SHARED_SEQUENCES[self.uid] = None
|
||||
|
||||
def next_sample(uid):
|
||||
"""Gets the next value from the generator `uid`.
|
||||
|
||||
To allow multiple generators to be used at the same time, we use `uid` to
|
||||
get a specific one. A single generator would cause the validation to
|
||||
overwrite the training generator.
|
||||
|
||||
Arguments:
|
||||
uid: int, generator identifier
|
||||
|
||||
Returns:
|
||||
The next value of generator `uid`.
|
||||
"""
|
||||
return six.next(_SHARED_SEQUENCES[uid])
|
||||
|
||||
|
||||
@tf_export('keras.utils.GeneratorEnqueuer')
|
||||
@ -658,145 +697,36 @@ class GeneratorEnqueuer(SequenceEnqueuer):
|
||||
will be incremented by one for each worker.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
generator,
|
||||
def __init__(self, sequence,
|
||||
use_multiprocessing=False,
|
||||
wait_time=0.05,
|
||||
seed=None):
|
||||
self.wait_time = wait_time
|
||||
self._generator = generator
|
||||
if os.name is 'nt' and use_multiprocessing is True:
|
||||
# On Windows, avoid **SYSTEMATIC** error in `multiprocessing`:
|
||||
# `TypeError: can't pickle generator objects`
|
||||
# => Suggest multithreading instead of multiprocessing on Windows
|
||||
raise ValueError('Using a generator with `use_multiprocessing=True`'
|
||||
' is not supported on Windows (no marshalling of'
|
||||
' generators across process boundaries). Instead,'
|
||||
' use single thread/process or multithreading.')
|
||||
else:
|
||||
self._use_multiprocessing = use_multiprocessing
|
||||
self._threads = []
|
||||
self._stop_event = None
|
||||
self._manager = None
|
||||
self.queue = None
|
||||
self.seed = seed
|
||||
random_seed=None):
|
||||
super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing)
|
||||
self.random_seed = random_seed
|
||||
|
||||
def _data_generator_task(self):
|
||||
if self._use_multiprocessing is False:
|
||||
while not self._stop_event.is_set():
|
||||
with self.genlock:
|
||||
try:
|
||||
if (self.queue is not None and
|
||||
self.queue.qsize() < self.max_queue_size):
|
||||
# On all OSes, avoid **SYSTEMATIC** error
|
||||
# in multithreading mode:
|
||||
# `ValueError: generator already executing`
|
||||
# => Serialize calls to
|
||||
# infinite iterator/generator's next() function
|
||||
generator_output = next(self._generator)
|
||||
self.queue.put((True, generator_output))
|
||||
else:
|
||||
time.sleep(self.wait_time)
|
||||
except StopIteration:
|
||||
break
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# Can't pickle tracebacks.
|
||||
# As a compromise, print the traceback and pickle None instead.
|
||||
if not hasattr(e, '__traceback__'):
|
||||
setattr(e, '__traceback__', sys.exc_info()[2])
|
||||
self.queue.put((False, e))
|
||||
self._stop_event.set()
|
||||
break
|
||||
else:
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
if (self.queue is not None and
|
||||
self.queue.qsize() < self.max_queue_size):
|
||||
generator_output = next(self._generator)
|
||||
self.queue.put((True, generator_output))
|
||||
else:
|
||||
time.sleep(self.wait_time)
|
||||
except StopIteration:
|
||||
break
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# Can't pickle tracebacks.
|
||||
# As a compromise, print the traceback and pickle None instead.
|
||||
traceback.print_exc()
|
||||
setattr(e, '__traceback__', None)
|
||||
self.queue.put((False, e))
|
||||
self._stop_event.set()
|
||||
break
|
||||
|
||||
def start(self, workers=1, max_queue_size=10):
|
||||
"""Kicks off threads which add data from the generator into the queue.
|
||||
def _get_executor_init(self, workers):
|
||||
"""Gets the Pool initializer for multiprocessing.
|
||||
|
||||
Arguments:
|
||||
workers: number of worker threads
|
||||
max_queue_size: queue size
|
||||
(when full, threads could block on `put()`)
|
||||
workers: Number of works.
|
||||
|
||||
Returns:
|
||||
A Function to initialize the pool
|
||||
"""
|
||||
try:
|
||||
self.max_queue_size = max_queue_size
|
||||
if self._use_multiprocessing:
|
||||
self._manager = multiprocessing.Manager()
|
||||
self.queue = self._manager.Queue(maxsize=max_queue_size)
|
||||
self._stop_event = multiprocessing.Event()
|
||||
else:
|
||||
# On all OSes, avoid **SYSTEMATIC** error in multithreading mode:
|
||||
# `ValueError: generator already executing`
|
||||
# => Serialize calls to infinite iterator/generator's next() function
|
||||
self.genlock = threading.Lock()
|
||||
self.queue = queue.Queue(maxsize=max_queue_size)
|
||||
self._stop_event = threading.Event()
|
||||
def pool_fn(seqs):
|
||||
return multiprocessing.Pool(workers,
|
||||
initializer=init_pool_generator,
|
||||
initargs=(seqs, self.random_seed))
|
||||
return pool_fn
|
||||
|
||||
for _ in range(workers):
|
||||
if self._use_multiprocessing:
|
||||
# Reset random seed else all children processes
|
||||
# share the same seed
|
||||
np.random.seed(self.seed)
|
||||
thread = multiprocessing.Process(target=self._data_generator_task)
|
||||
thread.daemon = True
|
||||
if self.seed is not None:
|
||||
self.seed += 1
|
||||
else:
|
||||
thread = threading.Thread(target=self._data_generator_task)
|
||||
self._threads.append(thread)
|
||||
thread.start()
|
||||
except:
|
||||
self.stop()
|
||||
raise
|
||||
|
||||
def is_running(self):
|
||||
return self._stop_event is not None and not self._stop_event.is_set()
|
||||
|
||||
def stop(self, timeout=None):
|
||||
"""Stops running threads and wait for them to exit, if necessary.
|
||||
|
||||
Should be called by the same thread which called `start()`.
|
||||
|
||||
Arguments:
|
||||
timeout: maximum time to wait on `thread.join()`.
|
||||
"""
|
||||
if self.is_running():
|
||||
self._stop_event.set()
|
||||
|
||||
for thread in self._threads:
|
||||
if self._use_multiprocessing:
|
||||
if thread.is_alive():
|
||||
thread.terminate()
|
||||
else:
|
||||
# The thread.is_alive() test is subject to a race condition:
|
||||
# the thread could terminate right after the test and before the
|
||||
# join, rendering this test meaningless -> Call thread.join()
|
||||
# always, which is ok no matter what the status of the thread.
|
||||
thread.join(timeout)
|
||||
|
||||
if self._manager:
|
||||
self._manager.shutdown()
|
||||
|
||||
self._threads = []
|
||||
self._stop_event = None
|
||||
self.queue = None
|
||||
def _run(self):
|
||||
"""Submits request to the executor and queue the `Future` objects."""
|
||||
self._send_sequence() # Share the initial generator
|
||||
with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
|
||||
while True:
|
||||
if self.stop_signal.is_set():
|
||||
return
|
||||
self.queue.put(
|
||||
executor.apply_async(next_sample, (self.uid,)), block=True)
|
||||
|
||||
def get(self):
|
||||
"""Creates a generator to extract data from the queue.
|
||||
@ -808,24 +738,30 @@ class GeneratorEnqueuer(SequenceEnqueuer):
|
||||
`(inputs, targets)` or
|
||||
`(inputs, targets, sample_weights)`.
|
||||
"""
|
||||
while self.is_running():
|
||||
if not self.queue.empty():
|
||||
success, value = self.queue.get()
|
||||
# Rethrow any exceptions found in the queue
|
||||
if not success:
|
||||
six.reraise(value.__class__, value, value.__traceback__)
|
||||
# Yield regular values
|
||||
if value is not None:
|
||||
yield value
|
||||
else:
|
||||
all_finished = all([not thread.is_alive() for thread in self._threads])
|
||||
if all_finished and self.queue.empty():
|
||||
raise StopIteration()
|
||||
else:
|
||||
time.sleep(self.wait_time)
|
||||
|
||||
# Make sure to rethrow the first exception in the queue, if any
|
||||
while not self.queue.empty():
|
||||
success, value = self.queue.get()
|
||||
if not success:
|
||||
six.reraise(value.__class__, value, value.__traceback__)
|
||||
try:
|
||||
while self.is_running():
|
||||
inputs = self.queue.get(block=True).get()
|
||||
self.queue.task_done()
|
||||
if inputs is not None:
|
||||
yield inputs
|
||||
except StopIteration:
|
||||
# Special case for finite generators
|
||||
last_ones = []
|
||||
while self.queue.qsize() > 0:
|
||||
last_ones.append(self.queue.get(block=True))
|
||||
# Wait for them to complete
|
||||
for f in last_ones:
|
||||
f.wait()
|
||||
# Keep the good ones
|
||||
last_ones = [future.get() for future in last_ones if future.successful()]
|
||||
for inputs in last_ones:
|
||||
if inputs is not None:
|
||||
yield inputs
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
self.stop()
|
||||
if 'generator already executing' in str(e):
|
||||
raise RuntimeError(
|
||||
'Your generator is NOT thread-safe. '
|
||||
'Keras requires a thread-safe generator when '
|
||||
'`use_multiprocessing=False, workers > 1`. ')
|
||||
six.reraise(*sys.exc_info())
|
||||
|
@ -228,7 +228,7 @@ class TestEnqueuers(test.TestCase):
|
||||
FaultSequence(), use_multiprocessing=False)
|
||||
enqueuer.start(3, 10)
|
||||
gen_output = enqueuer.get()
|
||||
with self.assertRaises(StopIteration):
|
||||
with self.assertRaises(IndexError):
|
||||
next(gen_output)
|
||||
|
||||
def test_ordered_enqueuer_fail_processes(self):
|
||||
@ -236,7 +236,7 @@ class TestEnqueuers(test.TestCase):
|
||||
FaultSequence(), use_multiprocessing=True)
|
||||
enqueuer.start(3, 10)
|
||||
gen_output = enqueuer.get()
|
||||
with self.assertRaises(StopIteration):
|
||||
with self.assertRaises(IndexError):
|
||||
next(gen_output)
|
||||
|
||||
def test_on_epoch_end_processes(self):
|
||||
|
@ -5,7 +5,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get"
|
||||
|
@ -4,6 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get"
|
||||
|
@ -5,7 +5,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get"
|
||||
|
@ -4,6 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get"
|
||||
|
@ -104,7 +104,8 @@ do_pylint() {
|
||||
"^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
|
||||
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
|
||||
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
|
||||
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned"
|
||||
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
|
||||
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable"
|
||||
|
||||
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user