Don't let Layers and Variables build up in Layer's __setattr__ tracking
PiperOrigin-RevId: 235628910
This commit is contained in:
parent
3ef4e3c24e
commit
84fdb93a93
@ -49,6 +49,7 @@ from tensorflow.python.ops import variables as tf_variables
|
|||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.training.tracking import data_structures
|
from tensorflow.python.training.tracking import data_structures
|
||||||
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
||||||
|
from tensorflow.python.training.tracking import object_identity
|
||||||
from tensorflow.python.util import function_utils
|
from tensorflow.python.util import function_utils
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
@ -1685,6 +1686,52 @@ class Layer(trackable.Trackable):
|
|||||||
|
|
||||||
return nest.map_structure(_make_placeholder_like, output_shapes)
|
return nest.map_structure(_make_placeholder_like, output_shapes)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _obj_reference_counts(self):
|
||||||
|
"""A dictionary counting the number of attributes referencing an object."""
|
||||||
|
if not hasattr(self, '_obj_reference_counts_dict'):
|
||||||
|
super(Layer, self).__setattr__(
|
||||||
|
'_obj_reference_counts_dict',
|
||||||
|
object_identity.ObjectIdentityDictionary())
|
||||||
|
return self._obj_reference_counts_dict
|
||||||
|
|
||||||
|
def __delattr__(self, name):
|
||||||
|
existing_value = getattr(self, name, None)
|
||||||
|
|
||||||
|
# If this value is replacing an existing object assigned to an attribute, we
|
||||||
|
# should clean it out to avoid leaking memory. First we check if there are
|
||||||
|
# other attributes referencing it.
|
||||||
|
reference_counts = self._obj_reference_counts
|
||||||
|
if existing_value not in reference_counts:
|
||||||
|
super(Layer, self).__delattr__(name)
|
||||||
|
return
|
||||||
|
|
||||||
|
reference_count = reference_counts[existing_value]
|
||||||
|
if reference_count > 1:
|
||||||
|
# There are other remaining references. We can't remove this object from
|
||||||
|
# _layers etc.
|
||||||
|
reference_counts[existing_value] = reference_count - 1
|
||||||
|
super(Layer, self).__delattr__(name)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# This is the last remaining reference.
|
||||||
|
del reference_counts[existing_value]
|
||||||
|
|
||||||
|
super(Layer, self).__delattr__(name)
|
||||||
|
|
||||||
|
if (isinstance(existing_value, Layer)
|
||||||
|
or trackable_layer_utils.has_weights(existing_value)):
|
||||||
|
super(Layer, self).__setattr__(
|
||||||
|
'_layers',
|
||||||
|
[l for l in self._layers if l is not existing_value])
|
||||||
|
if isinstance(existing_value, tf_variables.Variable):
|
||||||
|
super(Layer, self).__setattr__(
|
||||||
|
'_trainable_weights',
|
||||||
|
[w for w in self._trainable_weights if w is not existing_value])
|
||||||
|
super(Layer, self).__setattr__(
|
||||||
|
'_non_trainable_weights',
|
||||||
|
[w for w in self._non_trainable_weights if w is not existing_value])
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
if (not getattr(self, '_setattr_tracking', True) or
|
if (not getattr(self, '_setattr_tracking', True) or
|
||||||
getattr(self, '_is_graph_network', False)):
|
getattr(self, '_is_graph_network', False)):
|
||||||
@ -1695,6 +1742,16 @@ class Layer(trackable.Trackable):
|
|||||||
value = data_structures.sticky_attribute_assignment(
|
value = data_structures.sticky_attribute_assignment(
|
||||||
trackable=self, value=value, name=name)
|
trackable=self, value=value, name=name)
|
||||||
|
|
||||||
|
reference_counts = self._obj_reference_counts
|
||||||
|
reference_counts[value] = reference_counts.get(value, 0) + 1
|
||||||
|
|
||||||
|
# Clean out the old attribute, which clears _layers and _trainable_weights
|
||||||
|
# if necessary.
|
||||||
|
try:
|
||||||
|
self.__delattr__(name)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Append value to self._layers if relevant
|
# Append value to self._layers if relevant
|
||||||
if (isinstance(value, Layer) or
|
if (isinstance(value, Layer) or
|
||||||
trackable_layer_utils.has_weights(value)):
|
trackable_layer_utils.has_weights(value)):
|
||||||
|
|||||||
@ -547,6 +547,27 @@ class NestedTrackingTest(test.TestCase):
|
|||||||
self.assertEqual(len(layer.losses), 3)
|
self.assertEqual(len(layer.losses), 3)
|
||||||
self.assertEqual(len(layer.updates), 3)
|
self.assertEqual(len(layer.updates), 3)
|
||||||
|
|
||||||
|
def test_attribute_reassignment(self):
|
||||||
|
l = keras.layers.Layer()
|
||||||
|
l.a = keras.layers.Layer()
|
||||||
|
l.a = []
|
||||||
|
l.a = variables.Variable(1.)
|
||||||
|
l.a = keras.layers.Layer()
|
||||||
|
last_assignment = keras.layers.Layer()
|
||||||
|
l.a = last_assignment
|
||||||
|
l.b = variables.Variable(1.)
|
||||||
|
del l.b
|
||||||
|
l.c = keras.layers.Layer()
|
||||||
|
del l.c
|
||||||
|
l.d = last_assignment
|
||||||
|
del l.d
|
||||||
|
self.assertEqual([last_assignment], l._layers)
|
||||||
|
self.assertEqual([], l.trainable_weights)
|
||||||
|
self.assertEqual([], l.non_trainable_weights)
|
||||||
|
self.assertEqual([], l.weights)
|
||||||
|
del l.a
|
||||||
|
self.assertEqual([], l._layers)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class NameScopingTest(keras_parameterized.TestCase):
|
class NameScopingTest(keras_parameterized.TestCase):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user