Fixing a memory leak in Keras.
Fixes: https://github.com/tensorflow/tensorflow/issues/37515 PiperOrigin-RevId: 302568217 Change-Id: I28d0eaf3602fea0461901680df24899f135ce649
This commit is contained in:
parent
60d6ea479e
commit
e918c6e6fa
@ -912,6 +912,7 @@ class KerasSequenceAdapter(GeneratorDataAdapter):
|
|||||||
self._size = len(x)
|
self._size = len(x)
|
||||||
self._shuffle_sequence = shuffle
|
self._shuffle_sequence = shuffle
|
||||||
self._keras_sequence = x
|
self._keras_sequence = x
|
||||||
|
self._enqueuer = None
|
||||||
super(KerasSequenceAdapter, self).__init__(
|
super(KerasSequenceAdapter, self).__init__(
|
||||||
x,
|
x,
|
||||||
shuffle=False, # Shuffle is handed in the _make_callable override.
|
shuffle=False, # Shuffle is handed in the _make_callable override.
|
||||||
@ -929,11 +930,11 @@ class KerasSequenceAdapter(GeneratorDataAdapter):
|
|||||||
max_queue_size):
|
max_queue_size):
|
||||||
if workers > 1 or (workers > 0 and use_multiprocessing):
|
if workers > 1 or (workers > 0 and use_multiprocessing):
|
||||||
def generator_fn():
|
def generator_fn():
|
||||||
enqueuer = data_utils.OrderedEnqueuer(
|
self._enqueuer = data_utils.OrderedEnqueuer(
|
||||||
x, use_multiprocessing=use_multiprocessing,
|
x, use_multiprocessing=use_multiprocessing,
|
||||||
shuffle=self._shuffle_sequence)
|
shuffle=self._shuffle_sequence)
|
||||||
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
|
self._enqueuer.start(workers=workers, max_queue_size=max_queue_size)
|
||||||
return enqueuer.get()
|
return self._enqueuer.get()
|
||||||
else:
|
else:
|
||||||
def generator_fn():
|
def generator_fn():
|
||||||
order = range(len(x))
|
order = range(len(x))
|
||||||
@ -954,6 +955,8 @@ class KerasSequenceAdapter(GeneratorDataAdapter):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def on_epoch_end(self):
|
def on_epoch_end(self):
|
||||||
|
if self._enqueuer:
|
||||||
|
self._enqueuer.stop()
|
||||||
self._keras_sequence.on_epoch_end()
|
self._keras_sequence.on_epoch_end()
|
||||||
|
|
||||||
|
|
||||||
|
@ -678,7 +678,7 @@ class SequenceEnqueuer(object):
|
|||||||
for data in datas:
|
for data in datas:
|
||||||
# Use the inputs; training, evaluating, predicting.
|
# Use the inputs; training, evaluating, predicting.
|
||||||
# ... stop sometime.
|
# ... stop sometime.
|
||||||
enqueuer.close()
|
enqueuer.stop()
|
||||||
```
|
```
|
||||||
|
|
||||||
The `enqueuer.get()` should be an infinite stream of datas.
|
The `enqueuer.get()` should be an infinite stream of datas.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user