From 96a633367ecd5ae9b31e128c2436b1a3f81b27fd Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 10 Oct 2018 20:50:21 -0700 Subject: [PATCH] Part 3/4 of the update of tf.keras to the 2.2.4 API. PiperOrigin-RevId: 216639755 --- tensorflow/python/keras/BUILD | 13 + .../python/keras/engine/training_generator.py | 19 +- .../keras/engine/training_generator_test.py | 307 +++++++++++++ .../python/keras/engine/training_test.py | 275 ----------- tensorflow/python/keras/utils/data_utils.py | 430 ++++++++---------- .../python/keras/utils/data_utils_test.py | 4 +- ...flow.keras.utils.-generator-enqueuer.pbtxt | 2 +- ...rflow.keras.utils.-sequence-enqueuer.pbtxt | 1 + ...flow.keras.utils.-generator-enqueuer.pbtxt | 2 +- ...rflow.keras.utils.-sequence-enqueuer.pbtxt | 1 + tensorflow/tools/ci_build/ci_sanity.sh | 3 +- 11 files changed, 518 insertions(+), 539 deletions(-) create mode 100644 tensorflow/python/keras/engine/training_generator_test.py diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index c4d23f117f8..a566c9acaba 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -718,6 +718,19 @@ py_test( ], ) +py_test( + name = "training_generator_test", + size = "enormous", + srcs = ["engine/training_generator_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "feature_columns_integration_test", size = "small", diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py index 2e074699da8..21f44423ec0 100644 --- a/tensorflow/python/keras/engine/training_generator.py +++ b/tensorflow/python/keras/engine/training_generator.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.keras import callbacks as cbks from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from 
tensorflow.python.keras.utils.data_utils import iter_sequence_infinite from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.keras.utils.generic_utils import Progbar @@ -45,7 +46,6 @@ def fit_generator(model, shuffle=True, initial_epoch=0): """See docstring for `Model.fit_generator`.""" - wait_time = 0.01 # in seconds epoch = initial_epoch do_validation = bool(validation_data) @@ -124,13 +124,12 @@ def fit_generator(model, else: enqueuer = GeneratorEnqueuer( generator, - use_multiprocessing=use_multiprocessing, - wait_time=wait_time) + use_multiprocessing=use_multiprocessing) enqueuer.start(workers=workers, max_queue_size=max_queue_size) output_generator = enqueuer.get() else: if is_sequence: - output_generator = iter(generator) + output_generator = iter_sequence_infinite(generator) else: output_generator = generator @@ -251,7 +250,6 @@ def evaluate_generator(model, stateful_metric_indices = [] steps_done = 0 - wait_time = 0.01 all_outs = [] batch_sizes = [] is_sequence = isinstance(generator, Sequence) @@ -279,13 +277,12 @@ def evaluate_generator(model, else: enqueuer = GeneratorEnqueuer( generator, - use_multiprocessing=use_multiprocessing, - wait_time=wait_time) + use_multiprocessing=use_multiprocessing) enqueuer.start(workers=workers, max_queue_size=max_queue_size) output_generator = enqueuer.get() else: if is_sequence: - output_generator = iter(generator) + output_generator = iter_sequence_infinite(generator) else: output_generator = generator @@ -354,7 +351,6 @@ def predict_generator(model, model._make_test_function() steps_done = 0 - wait_time = 0.01 all_outs = [] is_sequence = isinstance(generator, Sequence) if not is_sequence and use_multiprocessing and workers > 1: @@ -381,13 +377,12 @@ def predict_generator(model, else: enqueuer = GeneratorEnqueuer( generator, - use_multiprocessing=use_multiprocessing, - wait_time=wait_time) + 
use_multiprocessing=use_multiprocessing) enqueuer.start(workers=workers, max_queue_size=max_queue_size) output_generator = enqueuer.get() else: if is_sequence: - output_generator = iter(generator) + output_generator = iter_sequence_infinite(generator) else: output_generator = generator diff --git a/tensorflow/python/keras/engine/training_generator_test.py b/tensorflow/python/keras/engine/training_generator_test.py new file mode 100644 index 00000000000..88e89434242 --- /dev/null +++ b/tensorflow/python/keras/engine/training_generator_test.py @@ -0,0 +1,307 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for training routines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest + +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.framework import test_util as tf_test_util +from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.platform import test +from tensorflow.python.training.rmsprop import RMSPropOptimizer + + +class TestGeneratorMethods(test.TestCase): + + @unittest.skipIf( + os.name == 'nt', + 'use_multiprocessing=True does not work on windows properly.') + def test_generator_methods(self): + arr_data = np.random.random((50, 2)) + arr_labels = np.random.random((50,)) + + def custom_generator(): + batch_size = 10 + num_samples = 50 + while True: + batch_index = np.random.randint(0, num_samples - batch_size) + start = batch_index + end = start + batch_size + x = arr_data[start: end] + y = arr_labels[start: end] + yield x, y + + with self.cached_session(): + x = keras.Input((2,)) + y = keras.layers.Dense(1)(x) + fn_model = keras.models.Model(x, y) + fn_model.compile( + loss='mse', + optimizer='sgd', + metrics=['mae', metrics_module.CategoricalAccuracy()]) + + seq_model = keras.models.Sequential() + seq_model.add(keras.layers.Dense(1, input_shape=(2,))) + seq_model.compile(loss='mse', optimizer='sgd') + + for model in [fn_model, seq_model]: + model.fit_generator(custom_generator(), + steps_per_epoch=5, + epochs=1, + verbose=1, + max_queue_size=10, + workers=4, + use_multiprocessing=True) + model.fit_generator(custom_generator(), + steps_per_epoch=5, + epochs=1, + verbose=1, + max_queue_size=10, + use_multiprocessing=False) + model.fit_generator(custom_generator(), + steps_per_epoch=5, + epochs=1, + verbose=1, + max_queue_size=10, + use_multiprocessing=False, + validation_data=custom_generator(), + 
validation_steps=10) + model.fit_generator(custom_generator(), + steps_per_epoch=5, + validation_data=custom_generator(), + validation_steps=1, + workers=0) + model.predict_generator(custom_generator(), + steps=5, + max_queue_size=10, + workers=2, + use_multiprocessing=True) + model.predict_generator(custom_generator(), + steps=5, + max_queue_size=10, + use_multiprocessing=False) + model.predict_generator(custom_generator(), + steps=5, + max_queue_size=10, + workers=0) + model.evaluate_generator(custom_generator(), + steps=5, + max_queue_size=10, + workers=2, + verbose=1, + use_multiprocessing=True) + model.evaluate_generator(custom_generator(), + steps=5, + max_queue_size=10, + use_multiprocessing=False) + model.evaluate_generator(custom_generator(), + steps=5, + max_queue_size=10, + use_multiprocessing=False, + workers=0) + + def test_generator_methods_with_sample_weights(self): + arr_data = np.random.random((50, 2)) + arr_labels = np.random.random((50,)) + arr_sample_weights = np.random.random((50,)) + + def custom_generator(): + batch_size = 10 + num_samples = 50 + while True: + batch_index = np.random.randint(0, num_samples - batch_size) + start = batch_index + end = start + batch_size + x = arr_data[start: end] + y = arr_labels[start: end] + w = arr_sample_weights[start: end] + yield x, y, w + + with self.cached_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_shape=(2,))) + model.compile( + loss='mse', + optimizer='sgd', + metrics=['mae', metrics_module.CategoricalAccuracy()]) + + model.fit_generator(custom_generator(), + steps_per_epoch=5, + epochs=1, + verbose=1, + max_queue_size=10, + use_multiprocessing=False) + model.fit_generator(custom_generator(), + steps_per_epoch=5, + epochs=1, + verbose=1, + max_queue_size=10, + use_multiprocessing=False, + validation_data=custom_generator(), + validation_steps=10) + model.predict_generator(custom_generator(), + steps=5, + max_queue_size=10, + use_multiprocessing=False) + 
model.evaluate_generator(custom_generator(), + steps=5, + max_queue_size=10, + use_multiprocessing=False) + + def test_generator_methods_invalid_use_case(self): + + def custom_generator(): + while 1: + yield 0 + + with self.cached_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_shape=(2,))) + model.compile(loss='mse', optimizer='sgd') + + with self.assertRaises(ValueError): + model.fit_generator(custom_generator(), + steps_per_epoch=5, + epochs=1, + verbose=1, + max_queue_size=10, + use_multiprocessing=False) + with self.assertRaises(ValueError): + model.fit_generator(custom_generator(), + steps_per_epoch=5, + epochs=1, + verbose=1, + max_queue_size=10, + use_multiprocessing=False, + validation_data=custom_generator(), + validation_steps=10) + with self.assertRaises(AttributeError): + model.predict_generator(custom_generator(), + steps=5, + max_queue_size=10, + use_multiprocessing=False) + with self.assertRaises(ValueError): + model.evaluate_generator(custom_generator(), + steps=5, + max_queue_size=10, + use_multiprocessing=False) + + def test_training_with_sequences(self): + + class DummySequence(keras.utils.Sequence): + + def __getitem__(self, idx): + return np.zeros([10, 2]), np.ones([10]) + + def __len__(self): + return 10 + + arr_data = np.random.random((50, 2)) + arr_labels = np.random.random((50,)) + arr_sample_weights = np.random.random((50,)) + + def custom_generator(): + batch_size = 10 + num_samples = 50 + while True: + batch_index = np.random.randint(0, num_samples - batch_size) + start = batch_index + end = start + batch_size + x = arr_data[start: end] + y = arr_labels[start: end] + w = arr_sample_weights[start: end] + yield x, y, w + + with self.cached_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_shape=(2,))) + model.compile(loss='mse', optimizer='sgd') + + model.fit_generator(DummySequence(), + steps_per_epoch=10, + validation_data=custom_generator(), + 
validation_steps=1, + max_queue_size=10, + workers=0, + use_multiprocessing=True) + model.fit_generator(DummySequence(), + steps_per_epoch=10, + validation_data=custom_generator(), + validation_steps=1, + max_queue_size=10, + workers=0, + use_multiprocessing=False) + + @tf_test_util.run_in_graph_and_eager_modes + def test_generator_input_to_fit_eval_predict(self): + val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) + + def custom_generator(): + while True: + yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) + + inputs = keras.layers.Input(shape=(10,)) + x = keras.layers.Dense(10, activation='relu')(inputs) + outputs = keras.layers.Dense(1, activation='sigmoid')(x) + model = keras.Model(inputs, outputs) + + model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy') + model.fit( + custom_generator(), + steps_per_epoch=2, + validation_data=val_data, + epochs=2) + model.evaluate(custom_generator(), steps=2) + model.predict(custom_generator(), steps=2) + + @tf_test_util.run_in_graph_and_eager_modes + def test_sequence_input_to_fit_eval_predict(self): + val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) + + class CustomSequence(keras.utils.Sequence): + + def __getitem__(self, idx): + return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) + + def __len__(self): + return 2 + + inputs = keras.layers.Input(shape=(10,)) + x = keras.layers.Dense(10, activation='relu')(inputs) + outputs = keras.layers.Dense(1, activation='sigmoid')(x) + model = keras.Model(inputs, outputs) + + model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy') + model.fit(CustomSequence(), validation_data=val_data, epochs=2) + model.evaluate(CustomSequence()) + model.predict(CustomSequence()) + + with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'): + model.fit(CustomSequence(), y=np.ones([10, 1])) + + with self.assertRaisesRegexp(ValueError, + '`sample_weight` argument is not supported'): + 
model.fit(CustomSequence(), sample_weight=np.ones([10, 1])) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 868fd1dc696..bd6b0e1aa14 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -19,8 +19,6 @@ from __future__ import division from __future__ import print_function import logging -import os -import unittest import numpy as np @@ -1102,279 +1100,6 @@ class TestDynamicTrainability(test.TestCase): self.assertListEqual(outer_model.trainable_weights, []) -class TestGeneratorMethods(test.TestCase): - - @unittest.skipIf( - os.name == 'nt', - 'use_multiprocessing=True does not work on windows properly.') - def test_generator_methods(self): - arr_data = np.random.random((50, 2)) - arr_labels = np.random.random((50,)) - - def custom_generator(): - batch_size = 10 - num_samples = 50 - while True: - batch_index = np.random.randint(0, num_samples - batch_size) - start = batch_index - end = start + batch_size - x = arr_data[start: end] - y = arr_labels[start: end] - yield x, y - - with self.cached_session(): - x = keras.Input((2,)) - y = keras.layers.Dense(1)(x) - fn_model = keras.models.Model(x, y) - fn_model.compile( - loss='mse', - optimizer='sgd', - metrics=['mae', metrics_module.CategoricalAccuracy()]) - - seq_model = keras.models.Sequential() - seq_model.add(keras.layers.Dense(1, input_shape=(2,))) - seq_model.compile(loss='mse', optimizer='sgd') - - for model in [fn_model, seq_model]: - model.fit_generator(custom_generator(), - steps_per_epoch=5, - epochs=1, - verbose=1, - max_queue_size=10, - workers=4, - use_multiprocessing=True) - model.fit_generator(custom_generator(), - steps_per_epoch=5, - epochs=1, - verbose=1, - max_queue_size=10, - use_multiprocessing=False) - model.fit_generator(custom_generator(), - steps_per_epoch=5, - epochs=1, - verbose=1, - max_queue_size=10, - 
use_multiprocessing=False, - validation_data=custom_generator(), - validation_steps=10) - model.fit_generator(custom_generator(), - steps_per_epoch=5, - validation_data=custom_generator(), - validation_steps=1, - workers=0) - model.predict_generator(custom_generator(), - steps=5, - max_queue_size=10, - workers=2, - use_multiprocessing=True) - model.predict_generator(custom_generator(), - steps=5, - max_queue_size=10, - use_multiprocessing=False) - model.predict_generator(custom_generator(), - steps=5, - max_queue_size=10, - workers=0) - model.evaluate_generator(custom_generator(), - steps=5, - max_queue_size=10, - workers=2, - verbose=1, - use_multiprocessing=True) - model.evaluate_generator(custom_generator(), - steps=5, - max_queue_size=10, - use_multiprocessing=False) - model.evaluate_generator(custom_generator(), - steps=5, - max_queue_size=10, - use_multiprocessing=False, - workers=0) - - def test_generator_methods_with_sample_weights(self): - arr_data = np.random.random((50, 2)) - arr_labels = np.random.random((50,)) - arr_sample_weights = np.random.random((50,)) - - def custom_generator(): - batch_size = 10 - num_samples = 50 - while True: - batch_index = np.random.randint(0, num_samples - batch_size) - start = batch_index - end = start + batch_size - x = arr_data[start: end] - y = arr_labels[start: end] - w = arr_sample_weights[start: end] - yield x, y, w - - with self.cached_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(1, input_shape=(2,))) - model.compile( - loss='mse', - optimizer='sgd', - metrics=['mae', metrics_module.CategoricalAccuracy()]) - - model.fit_generator(custom_generator(), - steps_per_epoch=5, - epochs=1, - verbose=1, - max_queue_size=10, - use_multiprocessing=False) - model.fit_generator(custom_generator(), - steps_per_epoch=5, - epochs=1, - verbose=1, - max_queue_size=10, - use_multiprocessing=False, - validation_data=custom_generator(), - validation_steps=10) - model.predict_generator(custom_generator(), 
- steps=5, - max_queue_size=10, - use_multiprocessing=False) - model.evaluate_generator(custom_generator(), - steps=5, - max_queue_size=10, - use_multiprocessing=False) - - def test_generator_methods_invalid_use_case(self): - - def custom_generator(): - while 1: - yield 0 - - with self.cached_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(1, input_shape=(2,))) - model.compile(loss='mse', optimizer='sgd') - - with self.assertRaises(ValueError): - model.fit_generator(custom_generator(), - steps_per_epoch=5, - epochs=1, - verbose=1, - max_queue_size=10, - use_multiprocessing=False) - with self.assertRaises(ValueError): - model.fit_generator(custom_generator(), - steps_per_epoch=5, - epochs=1, - verbose=1, - max_queue_size=10, - use_multiprocessing=False, - validation_data=custom_generator(), - validation_steps=10) - with self.assertRaises(AttributeError): - model.predict_generator(custom_generator(), - steps=5, - max_queue_size=10, - use_multiprocessing=False) - with self.assertRaises(ValueError): - model.evaluate_generator(custom_generator(), - steps=5, - max_queue_size=10, - use_multiprocessing=False) - - def test_training_with_sequences(self): - - class DummySequence(keras.utils.Sequence): - - def __getitem__(self, idx): - return np.zeros([10, 2]), np.ones([10]) - - def __len__(self): - return 10 - - arr_data = np.random.random((50, 2)) - arr_labels = np.random.random((50,)) - arr_sample_weights = np.random.random((50,)) - - def custom_generator(): - batch_size = 10 - num_samples = 50 - while True: - batch_index = np.random.randint(0, num_samples - batch_size) - start = batch_index - end = start + batch_size - x = arr_data[start: end] - y = arr_labels[start: end] - w = arr_sample_weights[start: end] - yield x, y, w - - with self.cached_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(1, input_shape=(2,))) - model.compile(loss='mse', optimizer='sgd') - - model.fit_generator(DummySequence(), - 
steps_per_epoch=10, - validation_data=custom_generator(), - validation_steps=1, - max_queue_size=10, - workers=0, - use_multiprocessing=True) - model.fit_generator(DummySequence(), - steps_per_epoch=10, - validation_data=custom_generator(), - validation_steps=1, - max_queue_size=10, - workers=0, - use_multiprocessing=False) - - @tf_test_util.run_in_graph_and_eager_modes - def test_generator_input_to_fit_eval_predict(self): - val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) - - def custom_generator(): - while True: - yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) - - inputs = keras.layers.Input(shape=(10,)) - x = keras.layers.Dense(10, activation='relu')(inputs) - outputs = keras.layers.Dense(1, activation='sigmoid')(x) - model = keras.Model(inputs, outputs) - - model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy') - model.fit( - custom_generator(), - steps_per_epoch=2, - validation_data=val_data, - epochs=2) - model.evaluate(custom_generator(), steps=2) - model.predict(custom_generator(), steps=2) - - @tf_test_util.run_in_graph_and_eager_modes - def test_sequence_input_to_fit_eval_predict(self): - val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) - - class CustomSequence(keras.utils.Sequence): - - def __getitem__(self, idx): - return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) - - def __len__(self): - return 2 - - inputs = keras.layers.Input(shape=(10,)) - x = keras.layers.Dense(10, activation='relu')(inputs) - outputs = keras.layers.Dense(1, activation='sigmoid')(x) - model = keras.Model(inputs, outputs) - - model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy') - model.fit(CustomSequence(), validation_data=val_data, epochs=2) - model.evaluate(CustomSequence()) - model.predict(CustomSequence()) - - with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'): - model.fit(CustomSequence(), y=np.ones([10, 1])) - - with self.assertRaisesRegexp(ValueError, - 
'`sample_weight` argument is not supported'): - model.fit(CustomSequence(), sample_weight=np.ones([10, 1])) - - class TestTrainingUtils(test.TestCase): def test_check_array_lengths(self): diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index b736daa46de..01a9d61a84c 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -30,7 +30,6 @@ import sys import tarfile import threading import time -import traceback import zipfile import numpy as np @@ -117,16 +116,16 @@ def _extract_archive(file_path, path='.', archive_format='auto'): """ if archive_format is None: return False - if archive_format is 'auto': + if archive_format == 'auto': archive_format = ['tar', 'zip'] if isinstance(archive_format, six.string_types): archive_format = [archive_format] for archive_type in archive_format: - if archive_type is 'tar': + if archive_type == 'tar': open_fn = tarfile.open is_match_fn = tarfile.is_tarfile - if archive_type is 'zip': + if archive_type == 'zip': open_fn = zipfile.ZipFile is_match_fn = zipfile.is_zipfile @@ -237,7 +236,7 @@ def get_file(fname, def dl_progress(count, block_size, total_size): if ProgressTracker.progbar is None: - if total_size is -1: + if total_size == -1: total_size = None ProgressTracker.progbar = Progbar(total_size) else: @@ -288,7 +287,7 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535): Returns: The file hash """ - if (algorithm is 'sha256') or (algorithm is 'auto' and len(hash) is 64): + if (algorithm == 'sha256') or (algorithm == 'auto' and len(hash) == 64): hasher = hashlib.sha256() else: hasher = hashlib.md5() @@ -314,8 +313,7 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535): Returns: Whether the file is valid """ - if ((algorithm is 'sha256') or - (algorithm is 'auto' and len(file_hash) is 64)): + if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64): hasher = 'sha256' else: 
hasher = 'md5' @@ -400,14 +398,23 @@ class Sequence(object): pass def __iter__(self): - """Creates an infinite generator that iterate over the Sequence. + """Create a generator that iterate over the Sequence.""" + for item in (self[i] for i in range(len(self))): + yield item - Yields: - Sequence items. - """ - while True: - for item in (self[i] for i in range(len(self))): - yield item + +def iter_sequence_infinite(seq): + """Iterates indefinitely over a Sequence. + + Arguments: + seq: Sequence instance. + + Yields: + Batches of data from the Sequence. + """ + while True: + for item in seq: + yield item # Global variables to be shared across processes @@ -445,7 +452,7 @@ class SequenceEnqueuer(object): The task of an Enqueuer is to use parallelism to speed up preprocessing. This is done with processes or threads. - Examples: + Example: ```python enqueuer = SequenceEnqueuer(...) @@ -458,61 +465,10 @@ class SequenceEnqueuer(object): ``` The `enqueuer.get()` should be an infinite stream of datas. - """ - @abstractmethod - def is_running(self): - raise NotImplementedError - - @abstractmethod - def start(self, workers=1, max_queue_size=10): - """Starts the handler's workers. - - Arguments: - workers: number of worker threads - max_queue_size: queue size - (when full, threads could block on `put()`). - """ - raise NotImplementedError - - @abstractmethod - def stop(self, timeout=None): - """Stop running threads and wait for them to exit, if necessary. - - Should be called by the same thread which called start(). - - Arguments: - timeout: maximum time to wait on thread.join() - """ - raise NotImplementedError - - @abstractmethod - def get(self): - """Creates a generator to extract data from the queue. - - Skip the data if it is `None`. - - Returns: - Generator yielding tuples `(inputs, targets)` - or `(inputs, targets, sample_weights)`. 
- """ - raise NotImplementedError - - -@tf_export('keras.utils.OrderedEnqueuer') -class OrderedEnqueuer(SequenceEnqueuer): - """Builds a Enqueuer from a Sequence. - - Used in `fit_generator`, `evaluate_generator`, `predict_generator`. - - Arguments: - sequence: A `keras.utils.data_utils.Sequence` object. - use_multiprocessing: use multiprocessing if True, otherwise threading - shuffle: whether to shuffle the data at the beginning of each epoch - """ - - def __init__(self, sequence, use_multiprocessing=False, shuffle=False): + def __init__(self, sequence, + use_multiprocessing=False): self.sequence = sequence self.use_multiprocessing = use_multiprocessing @@ -535,7 +491,6 @@ class OrderedEnqueuer(SequenceEnqueuer): self.uid = _SEQUENCE_COUNTER.value _SEQUENCE_COUNTER.value += 1 - self.shuffle = shuffle self.workers = 0 self.executor_fn = None self.queue = None @@ -546,16 +501,15 @@ class OrderedEnqueuer(SequenceEnqueuer): return self.stop_signal is not None and not self.stop_signal.is_set() def start(self, workers=1, max_queue_size=10): - """Start the handler's workers. + """Starts the handler's workers. Arguments: - workers: number of worker threads + workers: Number of workers. max_queue_size: queue size (when full, workers could block on `put()`) """ if self.use_multiprocessing: - self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda - workers, initializer=init_pool, initargs=(seqs,)) + self.executor_fn = self._get_executor_init(workers) else: # We do not need the init since it's threads. self.executor_fn = lambda _: ThreadPool(workers) @@ -566,6 +520,87 @@ class OrderedEnqueuer(SequenceEnqueuer): self.run_thread.daemon = True self.run_thread.start() + def _send_sequence(self): + """Sends current Iterable to all workers.""" + # For new processes that may spawn + _SHARED_SEQUENCES[self.uid] = self.sequence + + def stop(self, timeout=None): + """Stops running threads and wait for them to exit, if necessary. 
+ + Should be called by the same thread which called `start()`. + + Arguments: + timeout: maximum time to wait on `thread.join()` + """ + self.stop_signal.set() + with self.queue.mutex: + self.queue.queue.clear() + self.queue.unfinished_tasks = 0 + self.queue.not_full.notify() + self.run_thread.join(timeout) + _SHARED_SEQUENCES[self.uid] = None + + @abstractmethod + def _run(self): + """Submits request to the executor and queue the `Future` objects.""" + raise NotImplementedError + + @abstractmethod + def _get_executor_init(self, workers): + """Gets the Pool initializer for multiprocessing. + + Arguments: + workers: Number of workers. + + Returns: + Function, a Function to initialize the pool + """ + raise NotImplementedError + + @abstractmethod + def get(self): + """Creates a generator to extract data from the queue. + + Skip the data if it is `None`. + # Returns + Generator yielding tuples `(inputs, targets)` + or `(inputs, targets, sample_weights)`. + """ + raise NotImplementedError + + +@tf_export('keras.utils.OrderedEnqueuer') +class OrderedEnqueuer(SequenceEnqueuer): + """Builds a Enqueuer from a Sequence. + + Used in `fit_generator`, `evaluate_generator`, `predict_generator`. + + Arguments: + sequence: A `tf.keras.utils.data_utils.Sequence` object. + use_multiprocessing: use multiprocessing if True, otherwise threading + shuffle: whether to shuffle the data at the beginning of each epoch + """ + + def __init__(self, sequence, use_multiprocessing=False, shuffle=False): + super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing) + self.shuffle = shuffle + + def _get_executor_init(self, workers): + """Gets the Pool initializer for multiprocessing. + + Arguments: + workers: Number of workers. 
+ + Returns: + Function, a Function to initialize the pool + """ + def pool_fn(seqs): + return multiprocessing.Pool(workers, + initializer=init_pool_generator, + initargs=(seqs, self.random_seed)) + return pool_fn + def _wait_queue(self): """Wait for the queue to be empty.""" while True: @@ -615,30 +650,34 @@ class OrderedEnqueuer(SequenceEnqueuer): self.queue.task_done() if inputs is not None: yield inputs - except Exception as e: # pylint: disable=broad-except + except Exception: # pylint: disable=broad-except self.stop() - six.raise_from(StopIteration(e), e) + six.reraise(*sys.exc_info()) - def _send_sequence(self): - """Send current Sequence to all workers.""" - # For new processes that may spawn - _SHARED_SEQUENCES[self.uid] = self.sequence - def stop(self, timeout=None): - """Stops running threads and wait for them to exit, if necessary. +def init_pool_generator(gens, random_seed=None): + global _SHARED_SEQUENCES + _SHARED_SEQUENCES = gens - Should be called by the same thread which called `start()`. + if random_seed is not None: + ident = multiprocessing.current_process().ident + np.random.seed(random_seed + ident) - Arguments: - timeout: maximum time to wait on `thread.join()` - """ - self.stop_signal.set() - with self.queue.mutex: - self.queue.queue.clear() - self.queue.unfinished_tasks = 0 - self.queue.not_full.notify() - self.run_thread.join(timeout) - _SHARED_SEQUENCES[self.uid] = None + +def next_sample(uid): + """Gets the next value from the generator `uid`. + + To allow multiple generators to be used at the same time, we use `uid` to + get a specific one. A single generator would cause the validation to + overwrite the training generator. + + Arguments: + uid: int, generator identifier + + Returns: + The next value of generator `uid`. + """ + return six.next(_SHARED_SEQUENCES[uid]) @tf_export('keras.utils.GeneratorEnqueuer') @@ -658,145 +697,36 @@ class GeneratorEnqueuer(SequenceEnqueuer): will be incremented by one for each worker. 
""" - def __init__(self, - generator, + def __init__(self, sequence, use_multiprocessing=False, - wait_time=0.05, - seed=None): - self.wait_time = wait_time - self._generator = generator - if os.name is 'nt' and use_multiprocessing is True: - # On Windows, avoid **SYSTEMATIC** error in `multiprocessing`: - # `TypeError: can't pickle generator objects` - # => Suggest multithreading instead of multiprocessing on Windows - raise ValueError('Using a generator with `use_multiprocessing=True`' - ' is not supported on Windows (no marshalling of' - ' generators across process boundaries). Instead,' - ' use single thread/process or multithreading.') - else: - self._use_multiprocessing = use_multiprocessing - self._threads = [] - self._stop_event = None - self._manager = None - self.queue = None - self.seed = seed + random_seed=None): + super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing) + self.random_seed = random_seed - def _data_generator_task(self): - if self._use_multiprocessing is False: - while not self._stop_event.is_set(): - with self.genlock: - try: - if (self.queue is not None and - self.queue.qsize() < self.max_queue_size): - # On all OSes, avoid **SYSTEMATIC** error - # in multithreading mode: - # `ValueError: generator already executing` - # => Serialize calls to - # infinite iterator/generator's next() function - generator_output = next(self._generator) - self.queue.put((True, generator_output)) - else: - time.sleep(self.wait_time) - except StopIteration: - break - except Exception as e: # pylint: disable=broad-except - # Can't pickle tracebacks. - # As a compromise, print the traceback and pickle None instead. 
- if not hasattr(e, '__traceback__'): - setattr(e, '__traceback__', sys.exc_info()[2]) - self.queue.put((False, e)) - self._stop_event.set() - break - else: - while not self._stop_event.is_set(): - try: - if (self.queue is not None and - self.queue.qsize() < self.max_queue_size): - generator_output = next(self._generator) - self.queue.put((True, generator_output)) - else: - time.sleep(self.wait_time) - except StopIteration: - break - except Exception as e: # pylint: disable=broad-except - # Can't pickle tracebacks. - # As a compromise, print the traceback and pickle None instead. - traceback.print_exc() - setattr(e, '__traceback__', None) - self.queue.put((False, e)) - self._stop_event.set() - break - - def start(self, workers=1, max_queue_size=10): - """Kicks off threads which add data from the generator into the queue. + def _get_executor_init(self, workers): + """Gets the Pool initializer for multiprocessing. Arguments: - workers: number of worker threads - max_queue_size: queue size - (when full, threads could block on `put()`) + workers: Number of workers. 
+ + Returns: + A Function to initialize the pool """ - try: - self.max_queue_size = max_queue_size - if self._use_multiprocessing: - self._manager = multiprocessing.Manager() - self.queue = self._manager.Queue(maxsize=max_queue_size) - self._stop_event = multiprocessing.Event() - else: - # On all OSes, avoid **SYSTEMATIC** error in multithreading mode: - # `ValueError: generator already executing` - # => Serialize calls to infinite iterator/generator's next() function - self.genlock = threading.Lock() - self.queue = queue.Queue(maxsize=max_queue_size) - self._stop_event = threading.Event() + def pool_fn(seqs): + return multiprocessing.Pool(workers, + initializer=init_pool_generator, + initargs=(seqs, self.random_seed)) + return pool_fn - for _ in range(workers): - if self._use_multiprocessing: - # Reset random seed else all children processes - # share the same seed - np.random.seed(self.seed) - thread = multiprocessing.Process(target=self._data_generator_task) - thread.daemon = True - if self.seed is not None: - self.seed += 1 - else: - thread = threading.Thread(target=self._data_generator_task) - self._threads.append(thread) - thread.start() - except: - self.stop() - raise - - def is_running(self): - return self._stop_event is not None and not self._stop_event.is_set() - - def stop(self, timeout=None): - """Stops running threads and wait for them to exit, if necessary. - - Should be called by the same thread which called `start()`. - - Arguments: - timeout: maximum time to wait on `thread.join()`. - """ - if self.is_running(): - self._stop_event.set() - - for thread in self._threads: - if self._use_multiprocessing: - if thread.is_alive(): - thread.terminate() - else: - # The thread.is_alive() test is subject to a race condition: - # the thread could terminate right after the test and before the - # join, rendering this test meaningless -> Call thread.join() - # always, which is ok no matter what the status of the thread. 
- thread.join(timeout) - - if self._manager: - self._manager.shutdown() - - self._threads = [] - self._stop_event = None - self.queue = None + def _run(self): + """Submits request to the executor and queue the `Future` objects.""" + self._send_sequence() # Share the initial generator + with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: + while True: + if self.stop_signal.is_set(): + return + self.queue.put( + executor.apply_async(next_sample, (self.uid,)), block=True) def get(self): """Creates a generator to extract data from the queue. @@ -808,24 +738,30 @@ class GeneratorEnqueuer(SequenceEnqueuer): `(inputs, targets)` or `(inputs, targets, sample_weights)`. """ - while self.is_running(): - if not self.queue.empty(): - success, value = self.queue.get() - # Rethrow any exceptions found in the queue - if not success: - six.reraise(value.__class__, value, value.__traceback__) - # Yield regular values - if value is not None: - yield value - else: - all_finished = all([not thread.is_alive() for thread in self._threads]) - if all_finished and self.queue.empty(): - raise StopIteration() - else: - time.sleep(self.wait_time) - - # Make sure to rethrow the first exception in the queue, if any - while not self.queue.empty(): - success, value = self.queue.get() - if not success: - six.reraise(value.__class__, value, value.__traceback__) + try: + while self.is_running(): + inputs = self.queue.get(block=True).get() + self.queue.task_done() + if inputs is not None: + yield inputs + except StopIteration: + # Special case for finite generators + last_ones = [] + while self.queue.qsize() > 0: + last_ones.append(self.queue.get(block=True)) + # Wait for them to complete + for f in last_ones: + f.wait() + # Keep the good ones + last_ones = [future.get() for future in last_ones if future.successful()] + for inputs in last_ones: + if inputs is not None: + yield inputs + except Exception as e: # pylint: disable=broad-except + self.stop() + if 'generator already executing' in 
str(e): + raise RuntimeError( + 'Your generator is NOT thread-safe. ' + 'Keras requires a thread-safe generator when ' + '`use_multiprocessing=False, workers > 1`. ') + six.reraise(*sys.exc_info()) diff --git a/tensorflow/python/keras/utils/data_utils_test.py b/tensorflow/python/keras/utils/data_utils_test.py index 395df7e0e78..cc95803d6d6 100644 --- a/tensorflow/python/keras/utils/data_utils_test.py +++ b/tensorflow/python/keras/utils/data_utils_test.py @@ -228,7 +228,7 @@ class TestEnqueuers(test.TestCase): FaultSequence(), use_multiprocessing=False) enqueuer.start(3, 10) gen_output = enqueuer.get() - with self.assertRaises(StopIteration): + with self.assertRaises(IndexError): next(gen_output) def test_ordered_enqueuer_fail_processes(self): @@ -236,7 +236,7 @@ class TestEnqueuers(test.TestCase): FaultSequence(), use_multiprocessing=True) enqueuer.start(3, 10) gen_output = enqueuer.get() - with self.assertRaises(StopIteration): + with self.assertRaises(IndexError): next(gen_output) def test_on_epoch_end_processes(self): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt index 939fd547d06..6f5ad2dc963 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], " + argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "get" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-sequence-enqueuer.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-sequence-enqueuer.pbtxt index a9e499d1009..aa36d66f921 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-sequence-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-sequence-enqueuer.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" + argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "get" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt index 939fd547d06..6f5ad2dc963 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], " + argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "get" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence-enqueuer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence-enqueuer.pbtxt index a9e499d1009..aa36d66f921 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence-enqueuer.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" + argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "get" diff --git a/tensorflow/tools/ci_build/ci_sanity.sh 
b/tensorflow/tools/ci_build/ci_sanity.sh index a98c15d961f..503e602198e 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -104,7 +104,8 @@ do_pylint() { "^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\ "^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\ "^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\ -"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned" +"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\ +"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable" echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""