diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index b0741acfe30..4dfc28fd40f 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -912,6 +912,7 @@ class KerasSequenceAdapter(GeneratorDataAdapter): self._size = len(x) self._shuffle_sequence = shuffle self._keras_sequence = x + self._enqueuer = None super(KerasSequenceAdapter, self).__init__( x, shuffle=False, # Shuffle is handed in the _make_callable override. @@ -929,11 +930,11 @@ class KerasSequenceAdapter(GeneratorDataAdapter): max_queue_size): if workers > 1 or (workers > 0 and use_multiprocessing): def generator_fn(): - enqueuer = data_utils.OrderedEnqueuer( + self._enqueuer = data_utils.OrderedEnqueuer( x, use_multiprocessing=use_multiprocessing, shuffle=self._shuffle_sequence) - enqueuer.start(workers=workers, max_queue_size=max_queue_size) - return enqueuer.get() + self._enqueuer.start(workers=workers, max_queue_size=max_queue_size) + return self._enqueuer.get() else: def generator_fn(): order = range(len(x)) @@ -954,6 +955,8 @@ class KerasSequenceAdapter(GeneratorDataAdapter): return True def on_epoch_end(self): + if self._enqueuer: + self._enqueuer.stop() self._keras_sequence.on_epoch_end() diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index 5224356e877..73ffc19d293 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -678,7 +678,7 @@ class SequenceEnqueuer(object): for data in datas: # Use the inputs; training, evaluating, predicting. # ... stop sometime. - enqueuer.close() + enqueuer.stop() ``` The `enqueuer.get()` should be an infinite stream of datas.