Delete tracking of model's optimizer if include_optimizer=False when saving.

PiperOrigin-RevId: 359145092
Change-Id: Id337879de49ba9197876c487e4ea7c76318d8959
This commit is contained in:
Monica Song 2021-02-23 15:13:18 -08:00 committed by TensorFlower Gardener
parent 00cf670ff6
commit 5931dc3cb3
6 changed files with 51 additions and 14 deletions

View File

@ -2794,8 +2794,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# For any super.__delattr__() call, we will directly use the implementation
# in Trackable and skip the behavior in AutoTrackable. The Layer was
# originally use Trackable as base class, the change of using Module as base
# class forced us to have AutoTrackable in the class hierarchy. Skipping
# the __delattr__ and __setattr__ in AutoTrackable will keep the status quo.
# class forced us to have AutoTrackable in the class hierarchy.
#
# TODO(b/180760306) Keeping the status quo of skipping _delattr__ and
# __setattr__ in AutoTrackable may be unsustainable.
existing_value = getattr(self, name, None)
# If this value is replacing an existing object assigned to an attribute, we
@ -2901,8 +2903,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
backend.track_variable(val)
# Skip the auto trackable from tf.Module to keep status quo. See the comment
# at __delattr__.
# TODO(b/180760306) Skip the auto trackable from tf.Module to keep status
# quo. See the comment at __delattr__.
super(tracking.AutoTrackable, self).__setattr__(name, value)
def _gather_children_attribute(self, attribute):

View File

@ -2148,8 +2148,10 @@ class Layer(base_layer.Layer):
# For any super.__delattr__() call, we will directly use the implementation
# in Trackable and skip the behavior in AutoTrackable. The Layer was
# originally use Trackable as base class, the change of using Module as base
# class forced us to have AutoTrackable in the class hierarchy. Skipping
# the __delattr__ and __setattr__ in AutoTrackable will keep the status quo.
# class forced us to have AutoTrackable in the class hierarchy.
#
# TODO(b/180760306) Keeping the status quo of skipping _delattr__ and
# __setattr__ in AutoTrackable may be unsustainable.
existing_value = getattr(self, name, None)
# If this value is replacing an existing object assigned to an attribute, we
@ -2257,8 +2259,8 @@ class Layer(base_layer.Layer):
backend.track_variable(val)
# Skip the auto trackable from tf.Module to keep status quo. See the comment
# at __delattr__.
# TODO(b/180760306) Skip the auto trackable from tf.Module to keep status
# quo. See the comment at __delattr__.
super(tracking.AutoTrackable, self).__setattr__(name, value)
# This is a hack so that the is_layer (within

View File

@ -303,6 +303,28 @@ class TestSaveModel(test.TestCase, parameterized.TestCase):
self.assertAllClose(batch_loss, new_batch_loss)
@combinations.generate(combinations.combine(mode=['eager', 'graph']))
def test_save_include_optimizer_false(self):
def get_variables(file_name):
reader = training_module.load_checkpoint(
os.path.join(file_name, 'variables/variables'))
shape_from_key = reader.get_variable_to_shape_map()
return sorted(shape_from_key.keys())
model = keras.models.Sequential()
model.add(keras.layers.Dense(1))
model.compile('adam', loss='mse')
x, y = np.ones((10, 10)), np.ones((10, 1))
model.train_on_batch(x, y)
path = os.path.join(self.get_temp_dir(), 'no_optimizer')
model.save(path, save_format='tf', include_optimizer=False)
variables = get_variables(path)
for v in variables:
self.assertNotIn('optimizer', v)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_saving_model_with_custom_object(self):
with generic_utils.custom_object_scope(), self.cached_session():

View File

@ -50,7 +50,6 @@ from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking.tracking import delete_tracking
from tensorflow.python.util import compat
from tensorflow.python.util import nest
@ -284,7 +283,7 @@ class KerasObjectLoader(object):
# loading layers from the config, such as variables.
continue
for name in PUBLIC_ATTRIBUTES:
delete_tracking(node, name)
node._delete_tracking(name) # pylint: disable=protected-access
if isinstance(node, functional_lib.Functional):
# Delete the temporary layer dependencies, which were used to restore
@ -294,7 +293,7 @@ class KerasObjectLoader(object):
dependencies = list(node._self_unconditional_dependency_names) # pylint: disable=protected-access
for name in dependencies:
if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
delete_tracking(node, name)
node._delete_tracking(name) # pylint: disable=protected-access
def _add_children_recreated_from_config(self, obj, proto, node_id):
"""Recursively records objects recreated from config."""

View File

@ -32,7 +32,6 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import save as save_lib
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.
@ -81,6 +80,9 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
if not include_optimizer:
orig_optimizer = model.optimizer
model.optimizer = None
# TODO(b/180760306) Change to del model.optimizer if Layer's __delattr__
# calls AutoTrackable's __delattr__.
model._delete_tracking("optimizer") # pylint: disable=protected-access
# Trace all functions and signatures with `training=0` instead of using an
# already-set learning phase placeholder.

View File

@ -92,8 +92,7 @@ class AutoTrackable(base.Trackable):
super(AutoTrackable, self).__setattr__(name, value)
def __delattr__(self, name):
self._maybe_initialize_trackable()
delete_tracking(self, name)
self._delete_tracking(name)
super(AutoTrackable, self).__delattr__(name)
def _no_dependency(self, value):
@ -125,6 +124,17 @@ class AutoTrackable(base.Trackable):
functions[attribute_name] = attribute_value
return functions
def _delete_tracking(self, name):
"""Removes the tracking of name."""
self._maybe_initialize_trackable()
if name in self._unconditional_dependency_names:
del self._unconditional_dependency_names[name]
for index, (dep_name, _) in enumerate(
self._unconditional_checkpoint_dependencies):
if dep_name == name:
del self._unconditional_checkpoint_dependencies[index]
break
def delete_tracking(obj, name):
"""Removes the tracking of name from object."""