Delete tracking of model's optimizer if include_optimizer=False
when saving.
PiperOrigin-RevId: 359145092 Change-Id: Id337879de49ba9197876c487e4ea7c76318d8959
This commit is contained in:
parent
00cf670ff6
commit
5931dc3cb3
@@ -2794,8 +2794,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# For any super.__delattr__() call, we will directly use the implementation
|
||||
# in Trackable and skip the behavior in AutoTrackable. The Layer was
|
||||
# originally using Trackable as base class, the change of using Module as base
|
||||
# class forced us to have AutoTrackable in the class hierarchy. Skipping
|
||||
# the __delattr__ and __setattr__ in AutoTrackable will keep the status quo.
|
||||
# class forced us to have AutoTrackable in the class hierarchy.
|
||||
#
|
||||
# TODO(b/180760306) Keeping the status quo of skipping __delattr__ and
|
||||
# __setattr__ in AutoTrackable may be unsustainable.
|
||||
existing_value = getattr(self, name, None)
|
||||
|
||||
# If this value is replacing an existing object assigned to an attribute, we
|
||||
@@ -2901,8 +2903,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
|
||||
backend.track_variable(val)
|
||||
|
||||
# Skip the auto trackable from tf.Module to keep status quo. See the comment
|
||||
# at __delattr__.
|
||||
# TODO(b/180760306) Skip the auto trackable from tf.Module to keep status
|
||||
# quo. See the comment at __delattr__.
|
||||
super(tracking.AutoTrackable, self).__setattr__(name, value)
|
||||
|
||||
def _gather_children_attribute(self, attribute):
|
||||
|
@@ -2148,8 +2148,10 @@ class Layer(base_layer.Layer):
|
||||
# For any super.__delattr__() call, we will directly use the implementation
|
||||
# in Trackable and skip the behavior in AutoTrackable. The Layer was
|
||||
# originally using Trackable as base class, the change of using Module as base
|
||||
# class forced us to have AutoTrackable in the class hierarchy. Skipping
|
||||
# the __delattr__ and __setattr__ in AutoTrackable will keep the status quo.
|
||||
# class forced us to have AutoTrackable in the class hierarchy.
|
||||
#
|
||||
# TODO(b/180760306) Keeping the status quo of skipping __delattr__ and
|
||||
# __setattr__ in AutoTrackable may be unsustainable.
|
||||
existing_value = getattr(self, name, None)
|
||||
|
||||
# If this value is replacing an existing object assigned to an attribute, we
|
||||
@@ -2257,8 +2259,8 @@ class Layer(base_layer.Layer):
|
||||
|
||||
backend.track_variable(val)
|
||||
|
||||
# Skip the auto trackable from tf.Module to keep status quo. See the comment
|
||||
# at __delattr__.
|
||||
# TODO(b/180760306) Skip the auto trackable from tf.Module to keep status
|
||||
# quo. See the comment at __delattr__.
|
||||
super(tracking.AutoTrackable, self).__setattr__(name, value)
|
||||
|
||||
# This is a hack so that the is_layer (within
|
||||
|
@@ -303,6 +303,28 @@ class TestSaveModel(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllClose(batch_loss, new_batch_loss)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager', 'graph']))
|
||||
def test_save_include_optimizer_false(self):
|
||||
|
||||
def get_variables(file_name):
|
||||
reader = training_module.load_checkpoint(
|
||||
os.path.join(file_name, 'variables/variables'))
|
||||
shape_from_key = reader.get_variable_to_shape_map()
|
||||
return sorted(shape_from_key.keys())
|
||||
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(1))
|
||||
model.compile('adam', loss='mse')
|
||||
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
path = os.path.join(self.get_temp_dir(), 'no_optimizer')
|
||||
model.save(path, save_format='tf', include_optimizer=False)
|
||||
variables = get_variables(path)
|
||||
|
||||
for v in variables:
|
||||
self.assertNotIn('optimizer', v)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_saving_model_with_custom_object(self):
|
||||
with generic_utils.custom_object_scope(), self.cached_session():
|
||||
|
@@ -50,7 +50,6 @@ from tensorflow.python.saved_model import nested_structure_coder
|
||||
from tensorflow.python.saved_model import revived_types
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking.tracking import delete_tracking
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
@@ -284,7 +283,7 @@ class KerasObjectLoader(object):
|
||||
# loading layers from the config, such as variables.
|
||||
continue
|
||||
for name in PUBLIC_ATTRIBUTES:
|
||||
delete_tracking(node, name)
|
||||
node._delete_tracking(name) # pylint: disable=protected-access
|
||||
|
||||
if isinstance(node, functional_lib.Functional):
|
||||
# Delete the temporary layer dependencies, which were used to restore
|
||||
@@ -294,7 +293,7 @@ class KerasObjectLoader(object):
|
||||
dependencies = list(node._self_unconditional_dependency_names) # pylint: disable=protected-access
|
||||
for name in dependencies:
|
||||
if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
|
||||
delete_tracking(node, name)
|
||||
node._delete_tracking(name) # pylint: disable=protected-access
|
||||
|
||||
def _add_children_recreated_from_config(self, obj, proto, node_id):
|
||||
"""Recursively records objects recreated from config."""
|
||||
|
@@ -32,7 +32,6 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.saved_model import save as save_lib
|
||||
|
||||
|
||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||
# code in keras/saving must delay imports.
|
||||
|
||||
@@ -81,6 +80,9 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
|
||||
if not include_optimizer:
|
||||
orig_optimizer = model.optimizer
|
||||
model.optimizer = None
|
||||
# TODO(b/180760306) Change to del model.optimizer if Layer's __delattr__
|
||||
# calls AutoTrackable's __delattr__.
|
||||
model._delete_tracking("optimizer") # pylint: disable=protected-access
|
||||
|
||||
# Trace all functions and signatures with `training=0` instead of using an
|
||||
# already-set learning phase placeholder.
|
||||
|
@@ -92,8 +92,7 @@ class AutoTrackable(base.Trackable):
|
||||
super(AutoTrackable, self).__setattr__(name, value)
|
||||
|
||||
def __delattr__(self, name):
|
||||
self._maybe_initialize_trackable()
|
||||
delete_tracking(self, name)
|
||||
self._delete_tracking(name)
|
||||
super(AutoTrackable, self).__delattr__(name)
|
||||
|
||||
def _no_dependency(self, value):
|
||||
@@ -125,6 +124,17 @@ class AutoTrackable(base.Trackable):
|
||||
functions[attribute_name] = attribute_value
|
||||
return functions
|
||||
|
||||
def _delete_tracking(self, name):
|
||||
"""Removes the tracking of name."""
|
||||
self._maybe_initialize_trackable()
|
||||
if name in self._unconditional_dependency_names:
|
||||
del self._unconditional_dependency_names[name]
|
||||
for index, (dep_name, _) in enumerate(
|
||||
self._unconditional_checkpoint_dependencies):
|
||||
if dep_name == name:
|
||||
del self._unconditional_checkpoint_dependencies[index]
|
||||
break
|
||||
|
||||
|
||||
def delete_tracking(obj, name):
|
||||
"""Removes the tracking of name from object."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user