Add SaveOptions/CheckpointOptions to keras.Models.save_weights and keras_call_backs.ModelCheckpoint.

PiperOrigin-RevId: 316973333
Change-Id: I43f5b59ece4b862db41ab0e99f3c8df0a0d3b901
This commit is contained in:
Ken Franko 2020-06-17 15:06:41 -07:00 committed by TensorFlower Gardener
parent ed2b3d6e1e
commit 7a92859246
17 changed files with 82 additions and 20 deletions

View File

@ -54,7 +54,9 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.saved_model import save_options as save_options_lib
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import keras_export
@ -1115,6 +1117,9 @@ class ModelCheckpoint(Callback):
epochs, the monitored metric may potentially be less reliable (it
could reflect as little as 1 batch, since the metrics get reset every
epoch). Defaults to `'epoch'`.
options: Optional `tf.train.CheckpointOptions` object if
`save_weights_only` is true or optional `tf.saved_model.SavedOptions`
object if `save_weights_only` is false.
**kwargs: Additional arguments for backwards compatibility. Possible key
is `period`.
"""
@ -1127,6 +1132,7 @@ class ModelCheckpoint(Callback):
save_weights_only=False,
mode='auto',
save_freq='epoch',
options=None,
**kwargs):
super(ModelCheckpoint, self).__init__()
self._supports_tf_logs = True
@ -1140,6 +1146,20 @@ class ModelCheckpoint(Callback):
self._batches_seen_since_last_saving = 0
self._last_batch_seen = 0
if save_weights_only:
if options is None or isinstance(
options, checkpoint_options_lib.CheckpointOptions):
self._options = options or checkpoint_options_lib.CheckpointOptions()
else:
raise TypeError('If save_weights_only is True, then `options` must be'
'either None or a tf.train.CheckpointOptions')
else:
if options is None or isinstance(options, save_options_lib.SaveOptions):
self._options = options or save_options_lib.SaveOptions()
else:
raise TypeError('If save_weights_only is False, then `options` must be'
'either None or a tf.saved_model.SaveOptions')
# Deprecated field `load_weights_on_restart` is for loading the checkpoint
# file from `filepath` at the start of `model.fit()`
# TODO(rchao): Remove the arg during next breaking release.
@ -1269,9 +1289,10 @@ class ModelCheckpoint(Callback):
self.best, current, filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
self.model.save_weights(
filepath, overwrite=True, options=self._options)
else:
self.model.save(filepath, overwrite=True)
self.model.save(filepath, overwrite=True, options=self._options)
else:
if self.verbose > 0:
print('\nEpoch %05d: %s did not improve from %0.5f' %
@ -1280,9 +1301,10 @@ class ModelCheckpoint(Callback):
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
self.model.save_weights(
filepath, overwrite=True, options=self._options)
else:
self.model.save(filepath, overwrite=True)
self.model.save(filepath, overwrite=True, options=self._options)
self._maybe_remove_file()
except IOError as e:

View File

@ -49,9 +49,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import save_options as save_options_lib
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
try:
import h5py # pylint:disable=g-import-not-at-top
@ -666,6 +668,38 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
mode=mode,
save_freq=3)
# Case 9: `ModelCheckpoint` with valid and invalid `options` argument.
with self.assertRaisesRegexp(TypeError, 'tf.train.CheckpointOptions'):
keras.callbacks.ModelCheckpoint(
filepath,
monitor=monitor,
save_best_only=save_best_only,
save_weights_only=True,
mode=mode,
options=save_options_lib.SaveOptions())
with self.assertRaisesRegexp(TypeError, 'tf.saved_model.SaveOptions'):
keras.callbacks.ModelCheckpoint(
filepath,
monitor=monitor,
save_best_only=save_best_only,
save_weights_only=False,
mode=mode,
options=checkpoint_options_lib.CheckpointOptions())
keras.callbacks.ModelCheckpoint(
filepath,
monitor=monitor,
save_best_only=save_best_only,
save_weights_only=True,
mode=mode,
options=checkpoint_options_lib.CheckpointOptions())
keras.callbacks.ModelCheckpoint(
filepath,
monitor=monitor,
save_best_only=save_best_only,
save_weights_only=False,
mode=mode,
options=save_options_lib.SaveOptions())
def _get_dummy_resource_for_model_checkpoint_testing(self):
def get_input_datasets():

View File

@ -1979,7 +1979,11 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
save.save_model(self, filepath, overwrite, include_optimizer, save_format,
signatures, options)
def save_weights(self, filepath, overwrite=True, save_format=None):
def save_weights(self,
filepath,
overwrite=True,
save_format=None,
options=None):
"""Saves all layer weights.
Either saves in HDF5 or in TensorFlow format based on the `save_format`
@ -2032,6 +2036,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
'.keras' will default to HDF5 if `save_format` is `None`. Otherwise
`None` defaults to 'tf'.
options: Optional `tf.train.CheckpointOptions` object that specifies
options for saving weights.
Raises:
ImportError: If h5py is not available when attempting to save in HDF5
@ -2093,7 +2099,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
'the TensorFlow format the optimizer\'s state will not be '
'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
% (optimizer,))
self._trackable_saver.save(filepath, session=session)
self._trackable_saver.save(filepath, session=session, options=options)
# Record this checkpoint so it's visible from tf.train.latest_checkpoint.
checkpoint_management.update_checkpoint_state_internal(
save_dir=os.path.dirname(filepath),

View File

@ -302,7 +302,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -320,7 +320,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'options\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'None\'], "
}
member_method {
name: "on_batch_begin"

View File

@ -303,7 +303,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -303,7 +303,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -302,7 +302,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -320,7 +320,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -302,7 +302,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -320,7 +320,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'options\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'None\'], "
}
member_method {
name: "on_batch_begin"

View File

@ -303,7 +303,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -303,7 +303,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -302,7 +302,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"

View File

@ -320,7 +320,7 @@ tf_class {
}
member_method {
name: "save_weights"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"