Fixing a memory leak in Keras.

Fixes: https://github.com/tensorflow/tensorflow/issues/37515
PiperOrigin-RevId: 302568217
Change-Id: I28d0eaf3602fea0461901680df24899f135ce649
This commit is contained in:
Jiri Simsa 2020-03-23 18:51:46 -07:00 committed by TensorFlower Gardener
parent 60d6ea479e
commit e918c6e6fa
2 changed files with 7 additions and 4 deletions

View File

@ -912,6 +912,7 @@ class KerasSequenceAdapter(GeneratorDataAdapter):
self._size = len(x) self._size = len(x)
self._shuffle_sequence = shuffle self._shuffle_sequence = shuffle
self._keras_sequence = x self._keras_sequence = x
self._enqueuer = None
super(KerasSequenceAdapter, self).__init__( super(KerasSequenceAdapter, self).__init__(
x, x,
shuffle=False, # Shuffle is handed in the _make_callable override. shuffle=False, # Shuffle is handed in the _make_callable override.
@ -929,11 +930,11 @@ class KerasSequenceAdapter(GeneratorDataAdapter):
max_queue_size): max_queue_size):
if workers > 1 or (workers > 0 and use_multiprocessing): if workers > 1 or (workers > 0 and use_multiprocessing):
def generator_fn(): def generator_fn():
enqueuer = data_utils.OrderedEnqueuer( self._enqueuer = data_utils.OrderedEnqueuer(
x, use_multiprocessing=use_multiprocessing, x, use_multiprocessing=use_multiprocessing,
shuffle=self._shuffle_sequence) shuffle=self._shuffle_sequence)
enqueuer.start(workers=workers, max_queue_size=max_queue_size) self._enqueuer.start(workers=workers, max_queue_size=max_queue_size)
return enqueuer.get() return self._enqueuer.get()
else: else:
def generator_fn(): def generator_fn():
order = range(len(x)) order = range(len(x))
@ -954,6 +955,8 @@ class KerasSequenceAdapter(GeneratorDataAdapter):
return True return True
def on_epoch_end(self): def on_epoch_end(self):
if self._enqueuer:
self._enqueuer.stop()
self._keras_sequence.on_epoch_end() self._keras_sequence.on_epoch_end()

View File

@ -678,7 +678,7 @@ class SequenceEnqueuer(object):
for data in datas: for data in datas:
# Use the inputs; training, evaluating, predicting. # Use the inputs; training, evaluating, predicting.
# ... stop sometime. # ... stop sometime.
enqueuer.close() enqueuer.stop()
``` ```
The `enqueuer.get()` should be an infinite stream of datas. The `enqueuer.get()` should be an infinite stream of datas.