Delete tracking of model's optimizer if include_optimizer=False when saving.
PiperOrigin-RevId: 359145092 Change-Id: Id337879de49ba9197876c487e4ea7c76318d8959
This commit is contained in:
parent
00cf670ff6
commit
5931dc3cb3
@ -2794,8 +2794,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||||||
# For any super.__delattr__() call, we will directly use the implementation
|
# For any super.__delattr__() call, we will directly use the implementation
|
||||||
# in Trackable and skip the behavior in AutoTrackable. The Layer was
|
# in Trackable and skip the behavior in AutoTrackable. The Layer was
|
||||||
# originally use Trackable as base class, the change of using Module as base
|
# originally use Trackable as base class, the change of using Module as base
|
||||||
# class forced us to have AutoTrackable in the class hierarchy. Skipping
|
# class forced us to have AutoTrackable in the class hierarchy.
|
||||||
# the __delattr__ and __setattr__ in AutoTrackable will keep the status quo.
|
#
|
||||||
|
# TODO(b/180760306) Keeping the status quo of skipping _delattr__ and
|
||||||
|
# __setattr__ in AutoTrackable may be unsustainable.
|
||||||
existing_value = getattr(self, name, None)
|
existing_value = getattr(self, name, None)
|
||||||
|
|
||||||
# If this value is replacing an existing object assigned to an attribute, we
|
# If this value is replacing an existing object assigned to an attribute, we
|
||||||
@ -2901,8 +2903,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||||||
|
|
||||||
backend.track_variable(val)
|
backend.track_variable(val)
|
||||||
|
|
||||||
# Skip the auto trackable from tf.Module to keep status quo. See the comment
|
# TODO(b/180760306) Skip the auto trackable from tf.Module to keep status
|
||||||
# at __delattr__.
|
# quo. See the comment at __delattr__.
|
||||||
super(tracking.AutoTrackable, self).__setattr__(name, value)
|
super(tracking.AutoTrackable, self).__setattr__(name, value)
|
||||||
|
|
||||||
def _gather_children_attribute(self, attribute):
|
def _gather_children_attribute(self, attribute):
|
||||||
|
|||||||
@ -2148,8 +2148,10 @@ class Layer(base_layer.Layer):
|
|||||||
# For any super.__delattr__() call, we will directly use the implementation
|
# For any super.__delattr__() call, we will directly use the implementation
|
||||||
# in Trackable and skip the behavior in AutoTrackable. The Layer was
|
# in Trackable and skip the behavior in AutoTrackable. The Layer was
|
||||||
# originally use Trackable as base class, the change of using Module as base
|
# originally use Trackable as base class, the change of using Module as base
|
||||||
# class forced us to have AutoTrackable in the class hierarchy. Skipping
|
# class forced us to have AutoTrackable in the class hierarchy.
|
||||||
# the __delattr__ and __setattr__ in AutoTrackable will keep the status quo.
|
#
|
||||||
|
# TODO(b/180760306) Keeping the status quo of skipping _delattr__ and
|
||||||
|
# __setattr__ in AutoTrackable may be unsustainable.
|
||||||
existing_value = getattr(self, name, None)
|
existing_value = getattr(self, name, None)
|
||||||
|
|
||||||
# If this value is replacing an existing object assigned to an attribute, we
|
# If this value is replacing an existing object assigned to an attribute, we
|
||||||
@ -2257,8 +2259,8 @@ class Layer(base_layer.Layer):
|
|||||||
|
|
||||||
backend.track_variable(val)
|
backend.track_variable(val)
|
||||||
|
|
||||||
# Skip the auto trackable from tf.Module to keep status quo. See the comment
|
# TODO(b/180760306) Skip the auto trackable from tf.Module to keep status
|
||||||
# at __delattr__.
|
# quo. See the comment at __delattr__.
|
||||||
super(tracking.AutoTrackable, self).__setattr__(name, value)
|
super(tracking.AutoTrackable, self).__setattr__(name, value)
|
||||||
|
|
||||||
# This is a hack so that the is_layer (within
|
# This is a hack so that the is_layer (within
|
||||||
|
|||||||
@ -303,6 +303,28 @@ class TestSaveModel(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose(batch_loss, new_batch_loss)
|
self.assertAllClose(batch_loss, new_batch_loss)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['eager', 'graph']))
|
||||||
|
def test_save_include_optimizer_false(self):
|
||||||
|
|
||||||
|
def get_variables(file_name):
|
||||||
|
reader = training_module.load_checkpoint(
|
||||||
|
os.path.join(file_name, 'variables/variables'))
|
||||||
|
shape_from_key = reader.get_variable_to_shape_map()
|
||||||
|
return sorted(shape_from_key.keys())
|
||||||
|
|
||||||
|
model = keras.models.Sequential()
|
||||||
|
model.add(keras.layers.Dense(1))
|
||||||
|
model.compile('adam', loss='mse')
|
||||||
|
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||||
|
model.train_on_batch(x, y)
|
||||||
|
|
||||||
|
path = os.path.join(self.get_temp_dir(), 'no_optimizer')
|
||||||
|
model.save(path, save_format='tf', include_optimizer=False)
|
||||||
|
variables = get_variables(path)
|
||||||
|
|
||||||
|
for v in variables:
|
||||||
|
self.assertNotIn('optimizer', v)
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_saving_model_with_custom_object(self):
|
def test_saving_model_with_custom_object(self):
|
||||||
with generic_utils.custom_object_scope(), self.cached_session():
|
with generic_utils.custom_object_scope(), self.cached_session():
|
||||||
|
|||||||
@ -50,7 +50,6 @@ from tensorflow.python.saved_model import nested_structure_coder
|
|||||||
from tensorflow.python.saved_model import revived_types
|
from tensorflow.python.saved_model import revived_types
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.training.tracking import data_structures
|
from tensorflow.python.training.tracking import data_structures
|
||||||
from tensorflow.python.training.tracking.tracking import delete_tracking
|
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
@ -284,7 +283,7 @@ class KerasObjectLoader(object):
|
|||||||
# loading layers from the config, such as variables.
|
# loading layers from the config, such as variables.
|
||||||
continue
|
continue
|
||||||
for name in PUBLIC_ATTRIBUTES:
|
for name in PUBLIC_ATTRIBUTES:
|
||||||
delete_tracking(node, name)
|
node._delete_tracking(name) # pylint: disable=protected-access
|
||||||
|
|
||||||
if isinstance(node, functional_lib.Functional):
|
if isinstance(node, functional_lib.Functional):
|
||||||
# Delete the temporary layer dependencies, which were used to restore
|
# Delete the temporary layer dependencies, which were used to restore
|
||||||
@ -294,7 +293,7 @@ class KerasObjectLoader(object):
|
|||||||
dependencies = list(node._self_unconditional_dependency_names) # pylint: disable=protected-access
|
dependencies = list(node._self_unconditional_dependency_names) # pylint: disable=protected-access
|
||||||
for name in dependencies:
|
for name in dependencies:
|
||||||
if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
|
if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
|
||||||
delete_tracking(node, name)
|
node._delete_tracking(name) # pylint: disable=protected-access
|
||||||
|
|
||||||
def _add_children_recreated_from_config(self, obj, proto, node_id):
|
def _add_children_recreated_from_config(self, obj, proto, node_id):
|
||||||
"""Recursively records objects recreated from config."""
|
"""Recursively records objects recreated from config."""
|
||||||
|
|||||||
@ -32,7 +32,6 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
|||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.saved_model import save as save_lib
|
from tensorflow.python.saved_model import save as save_lib
|
||||||
|
|
||||||
|
|
||||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||||
# code in keras/saving must delay imports.
|
# code in keras/saving must delay imports.
|
||||||
|
|
||||||
@ -81,6 +80,9 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
|
|||||||
if not include_optimizer:
|
if not include_optimizer:
|
||||||
orig_optimizer = model.optimizer
|
orig_optimizer = model.optimizer
|
||||||
model.optimizer = None
|
model.optimizer = None
|
||||||
|
# TODO(b/180760306) Change to del model.optimizer if Layer's __delattr__
|
||||||
|
# calls AutoTrackable's __delattr__.
|
||||||
|
model._delete_tracking("optimizer") # pylint: disable=protected-access
|
||||||
|
|
||||||
# Trace all functions and signatures with `training=0` instead of using an
|
# Trace all functions and signatures with `training=0` instead of using an
|
||||||
# already-set learning phase placeholder.
|
# already-set learning phase placeholder.
|
||||||
|
|||||||
@ -92,8 +92,7 @@ class AutoTrackable(base.Trackable):
|
|||||||
super(AutoTrackable, self).__setattr__(name, value)
|
super(AutoTrackable, self).__setattr__(name, value)
|
||||||
|
|
||||||
def __delattr__(self, name):
|
def __delattr__(self, name):
|
||||||
self._maybe_initialize_trackable()
|
self._delete_tracking(name)
|
||||||
delete_tracking(self, name)
|
|
||||||
super(AutoTrackable, self).__delattr__(name)
|
super(AutoTrackable, self).__delattr__(name)
|
||||||
|
|
||||||
def _no_dependency(self, value):
|
def _no_dependency(self, value):
|
||||||
@ -125,6 +124,17 @@ class AutoTrackable(base.Trackable):
|
|||||||
functions[attribute_name] = attribute_value
|
functions[attribute_name] = attribute_value
|
||||||
return functions
|
return functions
|
||||||
|
|
||||||
|
def _delete_tracking(self, name):
|
||||||
|
"""Removes the tracking of name."""
|
||||||
|
self._maybe_initialize_trackable()
|
||||||
|
if name in self._unconditional_dependency_names:
|
||||||
|
del self._unconditional_dependency_names[name]
|
||||||
|
for index, (dep_name, _) in enumerate(
|
||||||
|
self._unconditional_checkpoint_dependencies):
|
||||||
|
if dep_name == name:
|
||||||
|
del self._unconditional_checkpoint_dependencies[index]
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
def delete_tracking(obj, name):
|
def delete_tracking(obj, name):
|
||||||
"""Removes the tracking of name from object."""
|
"""Removes the tracking of name from object."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user