Don't let Layers and Variables build up in Layer's __setattr__ tracking
PiperOrigin-RevId: 235628910
This commit is contained in:
parent
3ef4e3c24e
commit
84fdb93a93
@ -49,6 +49,7 @@ from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
||||
from tensorflow.python.training.tracking import object_identity
|
||||
from tensorflow.python.util import function_utils
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
@ -1685,6 +1686,52 @@ class Layer(trackable.Trackable):
|
||||
|
||||
return nest.map_structure(_make_placeholder_like, output_shapes)
|
||||
|
||||
@property
|
||||
def _obj_reference_counts(self):
|
||||
"""A dictionary counting the number of attributes referencing an object."""
|
||||
if not hasattr(self, '_obj_reference_counts_dict'):
|
||||
super(Layer, self).__setattr__(
|
||||
'_obj_reference_counts_dict',
|
||||
object_identity.ObjectIdentityDictionary())
|
||||
return self._obj_reference_counts_dict
|
||||
|
||||
def __delattr__(self, name):
|
||||
existing_value = getattr(self, name, None)
|
||||
|
||||
# If this value is replacing an existing object assigned to an attribute, we
|
||||
# should clean it out to avoid leaking memory. First we check if there are
|
||||
# other attributes referencing it.
|
||||
reference_counts = self._obj_reference_counts
|
||||
if existing_value not in reference_counts:
|
||||
super(Layer, self).__delattr__(name)
|
||||
return
|
||||
|
||||
reference_count = reference_counts[existing_value]
|
||||
if reference_count > 1:
|
||||
# There are other remaining references. We can't remove this object from
|
||||
# _layers etc.
|
||||
reference_counts[existing_value] = reference_count - 1
|
||||
super(Layer, self).__delattr__(name)
|
||||
return
|
||||
else:
|
||||
# This is the last remaining reference.
|
||||
del reference_counts[existing_value]
|
||||
|
||||
super(Layer, self).__delattr__(name)
|
||||
|
||||
if (isinstance(existing_value, Layer)
|
||||
or trackable_layer_utils.has_weights(existing_value)):
|
||||
super(Layer, self).__setattr__(
|
||||
'_layers',
|
||||
[l for l in self._layers if l is not existing_value])
|
||||
if isinstance(existing_value, tf_variables.Variable):
|
||||
super(Layer, self).__setattr__(
|
||||
'_trainable_weights',
|
||||
[w for w in self._trainable_weights if w is not existing_value])
|
||||
super(Layer, self).__setattr__(
|
||||
'_non_trainable_weights',
|
||||
[w for w in self._non_trainable_weights if w is not existing_value])
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if (not getattr(self, '_setattr_tracking', True) or
|
||||
getattr(self, '_is_graph_network', False)):
|
||||
@ -1695,6 +1742,16 @@ class Layer(trackable.Trackable):
|
||||
value = data_structures.sticky_attribute_assignment(
|
||||
trackable=self, value=value, name=name)
|
||||
|
||||
reference_counts = self._obj_reference_counts
|
||||
reference_counts[value] = reference_counts.get(value, 0) + 1
|
||||
|
||||
# Clean out the old attribute, which clears _layers and _trainable_weights
|
||||
# if necessary.
|
||||
try:
|
||||
self.__delattr__(name)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Append value to self._layers if relevant
|
||||
if (isinstance(value, Layer) or
|
||||
trackable_layer_utils.has_weights(value)):
|
||||
|
@ -547,6 +547,27 @@ class NestedTrackingTest(test.TestCase):
|
||||
self.assertEqual(len(layer.losses), 3)
|
||||
self.assertEqual(len(layer.updates), 3)
|
||||
|
||||
def test_attribute_reassignment(self):
|
||||
l = keras.layers.Layer()
|
||||
l.a = keras.layers.Layer()
|
||||
l.a = []
|
||||
l.a = variables.Variable(1.)
|
||||
l.a = keras.layers.Layer()
|
||||
last_assignment = keras.layers.Layer()
|
||||
l.a = last_assignment
|
||||
l.b = variables.Variable(1.)
|
||||
del l.b
|
||||
l.c = keras.layers.Layer()
|
||||
del l.c
|
||||
l.d = last_assignment
|
||||
del l.d
|
||||
self.assertEqual([last_assignment], l._layers)
|
||||
self.assertEqual([], l.trainable_weights)
|
||||
self.assertEqual([], l.non_trainable_weights)
|
||||
self.assertEqual([], l.weights)
|
||||
del l.a
|
||||
self.assertEqual([], l._layers)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class NameScopingTest(keras_parameterized.TestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user