Part 3/4 of the update of tf.keras to the 2.2.4 API.

PiperOrigin-RevId: 216639755
This commit is contained in:
Francois Chollet 2018-10-10 20:50:21 -07:00 committed by TensorFlower Gardener
parent 2b8f59243e
commit 96a633367e
11 changed files with 518 additions and 539 deletions

View File

@ -718,6 +718,19 @@ py_test(
],
)
py_test(
name = "training_generator_test",
size = "enormous",
srcs = ["engine/training_generator_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = [
":keras",
"//tensorflow/python:client_testlib",
"//third_party/py/numpy",
],
)
py_test(
name = "feature_columns_integration_test",
size = "small",

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras.utils.data_utils import iter_sequence_infinite
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer
from tensorflow.python.keras.utils.data_utils import Sequence
from tensorflow.python.keras.utils.generic_utils import Progbar
@ -45,7 +46,6 @@ def fit_generator(model,
shuffle=True,
initial_epoch=0):
"""See docstring for `Model.fit_generator`."""
wait_time = 0.01 # in seconds
epoch = initial_epoch
do_validation = bool(validation_data)
@ -124,13 +124,12 @@ def fit_generator(model,
else:
enqueuer = GeneratorEnqueuer(
generator,
use_multiprocessing=use_multiprocessing,
wait_time=wait_time)
use_multiprocessing=use_multiprocessing)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
output_generator = iter(generator)
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator
@ -251,7 +250,6 @@ def evaluate_generator(model,
stateful_metric_indices = []
steps_done = 0
wait_time = 0.01
all_outs = []
batch_sizes = []
is_sequence = isinstance(generator, Sequence)
@ -279,13 +277,12 @@ def evaluate_generator(model,
else:
enqueuer = GeneratorEnqueuer(
generator,
use_multiprocessing=use_multiprocessing,
wait_time=wait_time)
use_multiprocessing=use_multiprocessing)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
output_generator = iter(generator)
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator
@ -354,7 +351,6 @@ def predict_generator(model,
model._make_test_function()
steps_done = 0
wait_time = 0.01
all_outs = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
@ -381,13 +377,12 @@ def predict_generator(model,
else:
enqueuer = GeneratorEnqueuer(
generator,
use_multiprocessing=use_multiprocessing,
wait_time=wait_time)
use_multiprocessing=use_multiprocessing)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
output_generator = iter(generator)
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator

View File

@ -0,0 +1,307 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training routines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer
# Tests for Keras generator-based training: fit_generator / evaluate_generator /
# predict_generator, plus passing generators and Sequences directly to
# fit/evaluate/predict (the 2.2.4 API).
# NOTE(review): leading indentation was stripped in this dump; the structure is
# the standard class/method nesting of a test.TestCase.
class TestGeneratorMethods(test.TestCase):
# Raw generators cannot be marshalled across process boundaries on Windows,
# so the multiprocessing variant is skipped there.
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
def test_generator_methods(self):
arr_data = np.random.random((50, 2))
arr_labels = np.random.random((50,))
# Infinite generator yielding random (x, y) batches of 10 samples.
def custom_generator():
batch_size = 10
num_samples = 50
while True:
batch_index = np.random.randint(0, num_samples - batch_size)
start = batch_index
end = start + batch_size
x = arr_data[start: end]
y = arr_labels[start: end]
yield x, y
with self.cached_session():
x = keras.Input((2,))
y = keras.layers.Dense(1)(x)
fn_model = keras.models.Model(x, y)
fn_model.compile(
loss='mse',
optimizer='sgd',
metrics=['mae', metrics_module.CategoricalAccuracy()])
seq_model = keras.models.Sequential()
seq_model.add(keras.layers.Dense(1, input_shape=(2,)))
seq_model.compile(loss='mse', optimizer='sgd')
# Exercise each generator entry point on both a functional and a
# sequential model, across worker/multiprocessing configurations.
for model in [fn_model, seq_model]:
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
workers=4,
use_multiprocessing=True)
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False)
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False,
validation_data=custom_generator(),
validation_steps=10)
# workers=0 consumes the generator on the main thread (no enqueuer).
model.fit_generator(custom_generator(),
steps_per_epoch=5,
validation_data=custom_generator(),
validation_steps=1,
workers=0)
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
workers=2,
use_multiprocessing=True)
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
workers=0)
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
workers=2,
verbose=1,
use_multiprocessing=True)
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False,
workers=0)
# Same workflows when the generator also yields per-sample weights.
def test_generator_methods_with_sample_weights(self):
arr_data = np.random.random((50, 2))
arr_labels = np.random.random((50,))
arr_sample_weights = np.random.random((50,))
# Infinite generator yielding (x, y, sample_weight) batches of 10.
def custom_generator():
batch_size = 10
num_samples = 50
while True:
batch_index = np.random.randint(0, num_samples - batch_size)
start = batch_index
end = start + batch_size
x = arr_data[start: end]
y = arr_labels[start: end]
w = arr_sample_weights[start: end]
yield x, y, w
with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(
loss='mse',
optimizer='sgd',
metrics=['mae', metrics_module.CategoricalAccuracy()])
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False)
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False,
validation_data=custom_generator(),
validation_steps=10)
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
# Malformed generator output must raise from every *_generator method.
def test_generator_methods_invalid_use_case(self):
# Yields a bare scalar instead of an (x, y) tuple.
def custom_generator():
while 1:
yield 0
with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')
with self.assertRaises(ValueError):
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False)
with self.assertRaises(ValueError):
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False,
validation_data=custom_generator(),
validation_steps=10)
with self.assertRaises(AttributeError):
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
with self.assertRaises(ValueError):
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
# Sequence inputs work with fit_generator regardless of multiprocessing,
# including with a generator supplying the validation data.
def test_training_with_sequences(self):
class DummySequence(keras.utils.Sequence):
def __getitem__(self, idx):
return np.zeros([10, 2]), np.ones([10])
def __len__(self):
return 10
arr_data = np.random.random((50, 2))
arr_labels = np.random.random((50,))
arr_sample_weights = np.random.random((50,))
def custom_generator():
batch_size = 10
num_samples = 50
while True:
batch_index = np.random.randint(0, num_samples - batch_size)
start = batch_index
end = start + batch_size
x = arr_data[start: end]
y = arr_labels[start: end]
w = arr_sample_weights[start: end]
yield x, y, w
with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')
model.fit_generator(DummySequence(),
steps_per_epoch=10,
validation_data=custom_generator(),
validation_steps=1,
max_queue_size=10,
workers=0,
use_multiprocessing=True)
model.fit_generator(DummySequence(),
steps_per_epoch=10,
validation_data=custom_generator(),
validation_steps=1,
max_queue_size=10,
workers=0,
use_multiprocessing=False)
# New in the 2.2.4 API: generators may be passed straight to
# fit/evaluate/predict (not only to the *_generator methods).
@tf_test_util.run_in_graph_and_eager_modes
def test_generator_input_to_fit_eval_predict(self):
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
def custom_generator():
while True:
yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
inputs = keras.layers.Input(shape=(10,))
x = keras.layers.Dense(10, activation='relu')(inputs)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
model.fit(
custom_generator(),
steps_per_epoch=2,
validation_data=val_data,
epochs=2)
model.evaluate(custom_generator(), steps=2)
model.predict(custom_generator(), steps=2)
# Sequences may also be passed straight to fit/evaluate/predict; `y` and
# `sample_weight` arguments are rejected since the Sequence supplies them.
@tf_test_util.run_in_graph_and_eager_modes
def test_sequence_input_to_fit_eval_predict(self):
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
class CustomSequence(keras.utils.Sequence):
def __getitem__(self, idx):
return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
def __len__(self):
return 2
inputs = keras.layers.Input(shape=(10,))
x = keras.layers.Dense(10, activation='relu')(inputs)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
model.fit(CustomSequence(), validation_data=val_data, epochs=2)
model.evaluate(CustomSequence())
model.predict(CustomSequence())
with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'):
model.fit(CustomSequence(), y=np.ones([10, 1]))
with self.assertRaisesRegexp(ValueError,
'`sample_weight` argument is not supported'):
model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))
if __name__ == '__main__':
test.main()

View File

@ -19,8 +19,6 @@ from __future__ import division
from __future__ import print_function
import logging
import os
import unittest
import numpy as np
@ -1102,279 +1100,6 @@ class TestDynamicTrainability(test.TestCase):
self.assertListEqual(outer_model.trainable_weights, [])
class TestGeneratorMethods(test.TestCase):
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
def test_generator_methods(self):
arr_data = np.random.random((50, 2))
arr_labels = np.random.random((50,))
def custom_generator():
batch_size = 10
num_samples = 50
while True:
batch_index = np.random.randint(0, num_samples - batch_size)
start = batch_index
end = start + batch_size
x = arr_data[start: end]
y = arr_labels[start: end]
yield x, y
with self.cached_session():
x = keras.Input((2,))
y = keras.layers.Dense(1)(x)
fn_model = keras.models.Model(x, y)
fn_model.compile(
loss='mse',
optimizer='sgd',
metrics=['mae', metrics_module.CategoricalAccuracy()])
seq_model = keras.models.Sequential()
seq_model.add(keras.layers.Dense(1, input_shape=(2,)))
seq_model.compile(loss='mse', optimizer='sgd')
for model in [fn_model, seq_model]:
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
workers=4,
use_multiprocessing=True)
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False)
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False,
validation_data=custom_generator(),
validation_steps=10)
model.fit_generator(custom_generator(),
steps_per_epoch=5,
validation_data=custom_generator(),
validation_steps=1,
workers=0)
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
workers=2,
use_multiprocessing=True)
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
workers=0)
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
workers=2,
verbose=1,
use_multiprocessing=True)
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False,
workers=0)
def test_generator_methods_with_sample_weights(self):
arr_data = np.random.random((50, 2))
arr_labels = np.random.random((50,))
arr_sample_weights = np.random.random((50,))
def custom_generator():
batch_size = 10
num_samples = 50
while True:
batch_index = np.random.randint(0, num_samples - batch_size)
start = batch_index
end = start + batch_size
x = arr_data[start: end]
y = arr_labels[start: end]
w = arr_sample_weights[start: end]
yield x, y, w
with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(
loss='mse',
optimizer='sgd',
metrics=['mae', metrics_module.CategoricalAccuracy()])
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False)
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False,
validation_data=custom_generator(),
validation_steps=10)
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
def test_generator_methods_invalid_use_case(self):
def custom_generator():
while 1:
yield 0
with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')
with self.assertRaises(ValueError):
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False)
with self.assertRaises(ValueError):
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False,
validation_data=custom_generator(),
validation_steps=10)
with self.assertRaises(AttributeError):
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
with self.assertRaises(ValueError):
model.evaluate_generator(custom_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
def test_training_with_sequences(self):
class DummySequence(keras.utils.Sequence):
def __getitem__(self, idx):
return np.zeros([10, 2]), np.ones([10])
def __len__(self):
return 10
arr_data = np.random.random((50, 2))
arr_labels = np.random.random((50,))
arr_sample_weights = np.random.random((50,))
def custom_generator():
batch_size = 10
num_samples = 50
while True:
batch_index = np.random.randint(0, num_samples - batch_size)
start = batch_index
end = start + batch_size
x = arr_data[start: end]
y = arr_labels[start: end]
w = arr_sample_weights[start: end]
yield x, y, w
with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')
model.fit_generator(DummySequence(),
steps_per_epoch=10,
validation_data=custom_generator(),
validation_steps=1,
max_queue_size=10,
workers=0,
use_multiprocessing=True)
model.fit_generator(DummySequence(),
steps_per_epoch=10,
validation_data=custom_generator(),
validation_steps=1,
max_queue_size=10,
workers=0,
use_multiprocessing=False)
@tf_test_util.run_in_graph_and_eager_modes
def test_generator_input_to_fit_eval_predict(self):
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
def custom_generator():
while True:
yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
inputs = keras.layers.Input(shape=(10,))
x = keras.layers.Dense(10, activation='relu')(inputs)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
model.fit(
custom_generator(),
steps_per_epoch=2,
validation_data=val_data,
epochs=2)
model.evaluate(custom_generator(), steps=2)
model.predict(custom_generator(), steps=2)
@tf_test_util.run_in_graph_and_eager_modes
def test_sequence_input_to_fit_eval_predict(self):
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
class CustomSequence(keras.utils.Sequence):
def __getitem__(self, idx):
return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
def __len__(self):
return 2
inputs = keras.layers.Input(shape=(10,))
x = keras.layers.Dense(10, activation='relu')(inputs)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
model.fit(CustomSequence(), validation_data=val_data, epochs=2)
model.evaluate(CustomSequence())
model.predict(CustomSequence())
with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'):
model.fit(CustomSequence(), y=np.ones([10, 1]))
with self.assertRaisesRegexp(ValueError,
'`sample_weight` argument is not supported'):
model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))
class TestTrainingUtils(test.TestCase):
def test_check_array_lengths(self):

View File

@ -30,7 +30,6 @@ import sys
import tarfile
import threading
import time
import traceback
import zipfile
import numpy as np
@ -117,16 +116,16 @@ def _extract_archive(file_path, path='.', archive_format='auto'):
"""
if archive_format is None:
return False
if archive_format is 'auto':
if archive_format == 'auto':
archive_format = ['tar', 'zip']
if isinstance(archive_format, six.string_types):
archive_format = [archive_format]
for archive_type in archive_format:
if archive_type is 'tar':
if archive_type == 'tar':
open_fn = tarfile.open
is_match_fn = tarfile.is_tarfile
if archive_type is 'zip':
if archive_type == 'zip':
open_fn = zipfile.ZipFile
is_match_fn = zipfile.is_zipfile
@ -237,7 +236,7 @@ def get_file(fname,
def dl_progress(count, block_size, total_size):
if ProgressTracker.progbar is None:
if total_size is -1:
if total_size == -1:
total_size = None
ProgressTracker.progbar = Progbar(total_size)
else:
@ -288,7 +287,7 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
Returns:
The file hash
"""
if (algorithm is 'sha256') or (algorithm is 'auto' and len(hash) is 64):
if (algorithm == 'sha256') or (algorithm == 'auto' and len(hash) == 64):
hasher = hashlib.sha256()
else:
hasher = hashlib.md5()
@ -314,8 +313,7 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
Returns:
Whether the file is valid
"""
if ((algorithm is 'sha256') or
(algorithm is 'auto' and len(file_hash) is 64)):
if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
hasher = 'sha256'
else:
hasher = 'md5'
@ -400,14 +398,23 @@ class Sequence(object):
pass
def __iter__(self):
"""Creates an infinite generator that iterate over the Sequence.
"""Create a generator that iterates over the Sequence."""
for item in (self[i] for i in range(len(self))):
yield item
Yields:
Sequence items.
"""
while True:
for item in (self[i] for i in range(len(self))):
yield item
def iter_sequence_infinite(seq):
  """Iterate over a Sequence forever, restarting whenever it is exhausted.

  Arguments:
    seq: Sequence instance.

  Yields:
    Batches of data from the Sequence.
  """
  # Each pass re-iterates `seq` from the start, so a finite Sequence
  # behaves like an endless stream of its batches.
  while True:
    for batch in seq:
      yield batch
# Global variables to be shared across processes
@ -445,7 +452,7 @@ class SequenceEnqueuer(object):
The task of an Enqueuer is to use parallelism to speed up preprocessing.
This is done with processes or threads.
Examples:
Example:
```python
enqueuer = SequenceEnqueuer(...)
@ -458,61 +465,10 @@ class SequenceEnqueuer(object):
```
The `enqueuer.get()` should be an infinite stream of data.
"""
@abstractmethod
def is_running(self):
raise NotImplementedError
@abstractmethod
def start(self, workers=1, max_queue_size=10):
"""Starts the handler's workers.
Arguments:
workers: number of worker threads
max_queue_size: queue size
(when full, threads could block on `put()`).
"""
raise NotImplementedError
@abstractmethod
def stop(self, timeout=None):
"""Stop running threads and wait for them to exit, if necessary.
Should be called by the same thread which called start().
Arguments:
timeout: maximum time to wait on thread.join()
"""
raise NotImplementedError
@abstractmethod
def get(self):
"""Creates a generator to extract data from the queue.
Skip the data if it is `None`.
Returns:
Generator yielding tuples `(inputs, targets)`
or `(inputs, targets, sample_weights)`.
"""
raise NotImplementedError
@tf_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
"""Builds a Enqueuer from a Sequence.
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
Arguments:
sequence: A `keras.utils.data_utils.Sequence` object.
use_multiprocessing: use multiprocessing if True, otherwise threading
shuffle: whether to shuffle the data at the beginning of each epoch
"""
def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
def __init__(self, sequence,
use_multiprocessing=False):
self.sequence = sequence
self.use_multiprocessing = use_multiprocessing
@ -535,7 +491,6 @@ class OrderedEnqueuer(SequenceEnqueuer):
self.uid = _SEQUENCE_COUNTER.value
_SEQUENCE_COUNTER.value += 1
self.shuffle = shuffle
self.workers = 0
self.executor_fn = None
self.queue = None
@ -546,16 +501,15 @@ class OrderedEnqueuer(SequenceEnqueuer):
return self.stop_signal is not None and not self.stop_signal.is_set()
def start(self, workers=1, max_queue_size=10):
"""Start the handler's workers.
"""Starts the handler's workers.
Arguments:
workers: number of worker threads
workers: Number of workers.
max_queue_size: queue size
(when full, workers could block on `put()`)
"""
if self.use_multiprocessing:
self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda
workers, initializer=init_pool, initargs=(seqs,))
self.executor_fn = self._get_executor_init(workers)
else:
# We do not need the init since it's threads.
self.executor_fn = lambda _: ThreadPool(workers)
@ -566,6 +520,87 @@ class OrderedEnqueuer(SequenceEnqueuer):
self.run_thread.daemon = True
self.run_thread.start()
def _send_sequence(self):
"""Sends current Iterable to all workers."""
# For new processes that may spawn
_SHARED_SEQUENCES[self.uid] = self.sequence
def stop(self, timeout=None):
"""Stops running threads and wait for them to exit, if necessary.
Should be called by the same thread which called `start()`.
Arguments:
timeout: maximum time to wait on `thread.join()`
"""
self.stop_signal.set()
with self.queue.mutex:
self.queue.queue.clear()
self.queue.unfinished_tasks = 0
self.queue.not_full.notify()
self.run_thread.join(timeout)
_SHARED_SEQUENCES[self.uid] = None
@abstractmethod
def _run(self):
"""Submits request to the executor and queue the `Future` objects."""
raise NotImplementedError
@abstractmethod
def _get_executor_init(self, workers):
"""Gets the Pool initializer for multiprocessing.
Arguments:
workers: Number of workers.
Returns:
Function, a Function to initialize the pool
"""
raise NotImplementedError
@abstractmethod
def get(self):
"""Creates a generator to extract data from the queue.
Skip the data if it is `None`.
# Returns
Generator yielding tuples `(inputs, targets)`
or `(inputs, targets, sample_weights)`.
"""
raise NotImplementedError
@tf_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
"""Builds an Enqueuer from a Sequence.
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
Arguments:
sequence: A `tf.keras.utils.data_utils.Sequence` object.
use_multiprocessing: use multiprocessing if True, otherwise threading
shuffle: whether to shuffle the data at the beginning of each epoch
"""
def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
self.shuffle = shuffle
def _get_executor_init(self, workers):
"""Gets the Pool initializer for multiprocessing.
Arguments:
workers: Number of workers.
Returns:
Function, a Function to initialize the pool
"""
def pool_fn(seqs):
return multiprocessing.Pool(workers,
initializer=init_pool_generator,
initargs=(seqs, self.random_seed))
return pool_fn
def _wait_queue(self):
"""Wait for the queue to be empty."""
while True:
@ -615,30 +650,34 @@ class OrderedEnqueuer(SequenceEnqueuer):
self.queue.task_done()
if inputs is not None:
yield inputs
except Exception as e: # pylint: disable=broad-except
except Exception: # pylint: disable=broad-except
self.stop()
six.raise_from(StopIteration(e), e)
six.reraise(*sys.exc_info())
def _send_sequence(self):
"""Send current Sequence to all workers."""
# For new processes that may spawn
_SHARED_SEQUENCES[self.uid] = self.sequence
def stop(self, timeout=None):
"""Stops running threads and wait for them to exit, if necessary.
def init_pool_generator(gens, random_seed=None):
global _SHARED_SEQUENCES
_SHARED_SEQUENCES = gens
Should be called by the same thread which called `start()`.
if random_seed is not None:
ident = multiprocessing.current_process().ident
np.random.seed(random_seed + ident)
Arguments:
timeout: maximum time to wait on `thread.join()`
"""
self.stop_signal.set()
with self.queue.mutex:
self.queue.queue.clear()
self.queue.unfinished_tasks = 0
self.queue.not_full.notify()
self.run_thread.join(timeout)
_SHARED_SEQUENCES[self.uid] = None
def next_sample(uid):
  """Fetches the next value from the generator registered under `uid`.

  Each worker looks up its generator by `uid` so that several generators
  (e.g. a training and a validation generator) can be consumed at the same
  time without one overwriting the other.

  Arguments:
    uid: int, generator identifier

  Returns:
    The next value of generator `uid`.
  """
  generator = _SHARED_SEQUENCES[uid]
  return six.next(generator)
@tf_export('keras.utils.GeneratorEnqueuer')
@ -658,145 +697,36 @@ class GeneratorEnqueuer(SequenceEnqueuer):
will be incremented by one for each worker.
"""
def __init__(self,
generator,
def __init__(self, sequence,
use_multiprocessing=False,
wait_time=0.05,
seed=None):
self.wait_time = wait_time
self._generator = generator
if os.name is 'nt' and use_multiprocessing is True:
# On Windows, avoid **SYSTEMATIC** error in `multiprocessing`:
# `TypeError: can't pickle generator objects`
# => Suggest multithreading instead of multiprocessing on Windows
raise ValueError('Using a generator with `use_multiprocessing=True`'
' is not supported on Windows (no marshalling of'
' generators across process boundaries). Instead,'
' use single thread/process or multithreading.')
else:
self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self._manager = None
self.queue = None
self.seed = seed
random_seed=None):
super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing)
self.random_seed = random_seed
def _data_generator_task(self):
if self._use_multiprocessing is False:
while not self._stop_event.is_set():
with self.genlock:
try:
if (self.queue is not None and
self.queue.qsize() < self.max_queue_size):
# On all OSes, avoid **SYSTEMATIC** error
# in multithreading mode:
# `ValueError: generator already executing`
# => Serialize calls to
# infinite iterator/generator's next() function
generator_output = next(self._generator)
self.queue.put((True, generator_output))
else:
time.sleep(self.wait_time)
except StopIteration:
break
except Exception as e: # pylint: disable=broad-except
# Can't pickle tracebacks.
# As a compromise, print the traceback and pickle None instead.
if not hasattr(e, '__traceback__'):
setattr(e, '__traceback__', sys.exc_info()[2])
self.queue.put((False, e))
self._stop_event.set()
break
else:
while not self._stop_event.is_set():
try:
if (self.queue is not None and
self.queue.qsize() < self.max_queue_size):
generator_output = next(self._generator)
self.queue.put((True, generator_output))
else:
time.sleep(self.wait_time)
except StopIteration:
break
except Exception as e: # pylint: disable=broad-except
# Can't pickle tracebacks.
# As a compromise, print the traceback and pickle None instead.
traceback.print_exc()
setattr(e, '__traceback__', None)
self.queue.put((False, e))
self._stop_event.set()
break
def start(self, workers=1, max_queue_size=10):
"""Kicks off threads which add data from the generator into the queue.
def _get_executor_init(self, workers):
"""Gets the Pool initializer for multiprocessing.
Arguments:
workers: number of worker threads
max_queue_size: queue size
(when full, threads could block on `put()`)
workers: Number of workers.
Returns:
A Function to initialize the pool
"""
try:
self.max_queue_size = max_queue_size
if self._use_multiprocessing:
self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
# On all OSes, avoid **SYSTEMATIC** error in multithreading mode:
# `ValueError: generator already executing`
# => Serialize calls to infinite iterator/generator's next() function
self.genlock = threading.Lock()
self.queue = queue.Queue(maxsize=max_queue_size)
self._stop_event = threading.Event()
def pool_fn(seqs):
return multiprocessing.Pool(workers,
initializer=init_pool_generator,
initargs=(seqs, self.random_seed))
return pool_fn
for _ in range(workers):
if self._use_multiprocessing:
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.seed)
thread = multiprocessing.Process(target=self._data_generator_task)
thread.daemon = True
if self.seed is not None:
self.seed += 1
else:
thread = threading.Thread(target=self._data_generator_task)
self._threads.append(thread)
thread.start()
except:
self.stop()
raise
def is_running(self):
return self._stop_event is not None and not self._stop_event.is_set()
def stop(self, timeout=None):
"""Stops running threads and wait for them to exit, if necessary.
Should be called by the same thread which called `start()`.
Arguments:
timeout: maximum time to wait on `thread.join()`.
"""
if self.is_running():
self._stop_event.set()
for thread in self._threads:
if self._use_multiprocessing:
if thread.is_alive():
thread.terminate()
else:
# The thread.is_alive() test is subject to a race condition:
# the thread could terminate right after the test and before the
# join, rendering this test meaningless -> Call thread.join()
# always, which is ok no matter what the status of the thread.
thread.join(timeout)
if self._manager:
self._manager.shutdown()
self._threads = []
self._stop_event = None
self.queue = None
def _run(self):
"""Submits request to the executor and queue the `Future` objects."""
self._send_sequence() # Share the initial generator
with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
while True:
if self.stop_signal.is_set():
return
self.queue.put(
executor.apply_async(next_sample, (self.uid,)), block=True)
def get(self):
  """Creates a generator to extract data from the queue.

  Skips the data if it is `None`.

  Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
  """
  # NOTE: the pre-refactor poll/sleep body and the updated blocking body
  # were concatenated here (diff residue); only the updated implementation
  # is kept below.
  try:
    while self.is_running():
      # Block until the next Future is available, then block on its result.
      inputs = self.queue.get(block=True).get()
      self.queue.task_done()
      if inputs is not None:
        yield inputs
  except StopIteration:
    # Special case for finite generators: drain whatever the workers
    # already submitted instead of dropping it on the floor.
    last_ones = []
    while self.queue.qsize() > 0:
      last_ones.append(self.queue.get(block=True))
    # Wait for them to complete
    for f in last_ones:
      f.wait()
    # Keep the good ones
    last_ones = [future.get() for future in last_ones if future.successful()]
    for inputs in last_ones:
      if inputs is not None:
        yield inputs
  except Exception as e:  # pylint: disable=broad-except
    self.stop()
    if 'generator already executing' in str(e):
      # A non-thread-safe generator raced with itself across workers.
      raise RuntimeError(
          'Your generator is NOT thread-safe. '
          'Keras requires a thread-safe generator when '
          '`use_multiprocessing=False, workers > 1`. ')
    six.reraise(*sys.exc_info())

View File

@ -228,7 +228,7 @@ class TestEnqueuers(test.TestCase):
FaultSequence(), use_multiprocessing=False)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
with self.assertRaises(StopIteration):
with self.assertRaises(IndexError):
next(gen_output)
def test_ordered_enqueuer_fail_processes(self):
@ -236,7 +236,7 @@ class TestEnqueuers(test.TestCase):
FaultSequence(), use_multiprocessing=True)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
with self.assertRaises(StopIteration):
with self.assertRaises(IndexError):
next(gen_output)
def test_on_epoch_end_processes(self):

View File

@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], "
argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "get"

View File

@ -4,6 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "get"

View File

@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], "
argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "get"

View File

@ -4,6 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "get"

View File

@ -104,7 +104,8 @@ do_pylint() {
"^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned"
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable"
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""