Don't let Layers and Variables build up in Layer's __setattr__ tracking

PiperOrigin-RevId: 235628910
This commit is contained in:
Allen Lavoie 2019-02-25 17:01:30 -08:00 committed by TensorFlower Gardener
parent 3ef4e3c24e
commit 84fdb93a93
2 changed files with 78 additions and 0 deletions

View File

@ -49,6 +49,7 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import object_identity
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@ -1685,6 +1686,52 @@ class Layer(trackable.Trackable):
return nest.map_structure(_make_placeholder_like, output_shapes)
@property
def _obj_reference_counts(self):
  """Lazily-created dictionary mapping objects to attribute reference counts.

  Keys are compared by identity (ObjectIdentityDictionary), so distinct but
  equal objects are counted separately. Created on first access; stored via
  the base `__setattr__` to bypass this class's attribute tracking.
  """
  try:
    return self._obj_reference_counts_dict
  except AttributeError:
    counts = object_identity.ObjectIdentityDictionary()
    super(Layer, self).__setattr__('_obj_reference_counts_dict', counts)
    return counts
def __delattr__(self, name):
  """Deletes attribute `name`, pruning tracking state when appropriate.

  When the attribute's value is no longer referenced by any other attribute
  of this Layer, the object is also scrubbed from `_layers`,
  `_trainable_weights` and `_non_trainable_weights` so it does not leak.
  """
  value = getattr(self, name, None)
  reference_counts = self._obj_reference_counts
  if value not in reference_counts:
    # Untracked value (or the attribute does not exist): plain deletion.
    super(Layer, self).__delattr__(name)
    return
  remaining = reference_counts[value] - 1
  if remaining > 0:
    # Another attribute still points at this object, so it must stay in
    # `_layers` etc.; just decrement the count and delete the attribute.
    reference_counts[value] = remaining
    super(Layer, self).__delattr__(name)
    return
  # Last remaining reference: stop tracking the object entirely.
  del reference_counts[value]
  super(Layer, self).__delattr__(name)
  # Base `__setattr__` is used below so these list rebuilds don't re-enter
  # this class's attribute tracking. Removal is by identity, not equality.
  if (isinstance(value, Layer)
      or trackable_layer_utils.has_weights(value)):
    super(Layer, self).__setattr__(
        '_layers',
        [layer for layer in self._layers if layer is not value])
  if isinstance(value, tf_variables.Variable):
    super(Layer, self).__setattr__(
        '_trainable_weights',
        [w for w in self._trainable_weights if w is not value])
    super(Layer, self).__setattr__(
        '_non_trainable_weights',
        [w for w in self._non_trainable_weights if w is not value])
def __setattr__(self, name, value):
if (not getattr(self, '_setattr_tracking', True) or
getattr(self, '_is_graph_network', False)):
@ -1695,6 +1742,16 @@ class Layer(trackable.Trackable):
value = data_structures.sticky_attribute_assignment(
trackable=self, value=value, name=name)
reference_counts = self._obj_reference_counts
reference_counts[value] = reference_counts.get(value, 0) + 1
# Clean out the old attribute, which clears _layers and _trainable_weights
# if necessary.
try:
self.__delattr__(name)
except AttributeError:
pass
# Append value to self._layers if relevant
if (isinstance(value, Layer) or
trackable_layer_utils.has_weights(value)):

View File

@ -547,6 +547,27 @@ class NestedTrackingTest(test.TestCase):
self.assertEqual(len(layer.losses), 3)
self.assertEqual(len(layer.updates), 3)
def test_attribute_reassignment(self):
  """Reassigning/deleting attributes must not leak stale tracked objects."""
  layer = keras.layers.Layer()
  # Overwrite the same attribute with a mix of trackable and plain values;
  # only the object from the final assignment should remain tracked.
  layer.a = keras.layers.Layer()
  layer.a = []
  layer.a = variables.Variable(1.)
  layer.a = keras.layers.Layer()
  survivor = keras.layers.Layer()
  layer.a = survivor
  # Deleting an attribute removes its value from tracking too.
  layer.b = variables.Variable(1.)
  del layer.b
  layer.c = keras.layers.Layer()
  del layer.c
  # `d` aliases the object still held by `a`; deleting `d` must not evict
  # the object from `_layers` while `a` still references it.
  layer.d = survivor
  del layer.d
  self.assertEqual([survivor], layer._layers)
  self.assertEqual([], layer.trainable_weights)
  self.assertEqual([], layer.non_trainable_weights)
  self.assertEqual([], layer.weights)
  # Dropping the last reference finally clears `_layers`.
  del layer.a
  self.assertEqual([], layer._layers)
@test_util.run_all_in_graph_and_eager_modes
class NameScopingTest(keras_parameterized.TestCase):