Add SaveOptions/CheckpointOptions to keras.Models.save_weights and keras_call_backs.ModelCheckpoint.
PiperOrigin-RevId: 316973333 Change-Id: I43f5b59ece4b862db41ab0e99f3c8df0a0d3b901
This commit is contained in:
parent
ed2b3d6e1e
commit
7a92859246
@ -54,7 +54,9 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||
from tensorflow.python.saved_model import save_options as save_options_lib
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -1115,6 +1117,9 @@ class ModelCheckpoint(Callback):
|
||||
epochs, the monitored metric may potentially be less reliable (it
|
||||
could reflect as little as 1 batch, since the metrics get reset every
|
||||
epoch). Defaults to `'epoch'`.
|
||||
options: Optional `tf.train.CheckpointOptions` object if
|
||||
`save_weights_only` is true or optional `tf.saved_model.SavedOptions`
|
||||
object if `save_weights_only` is false.
|
||||
**kwargs: Additional arguments for backwards compatibility. Possible key
|
||||
is `period`.
|
||||
"""
|
||||
@ -1127,6 +1132,7 @@ class ModelCheckpoint(Callback):
|
||||
save_weights_only=False,
|
||||
mode='auto',
|
||||
save_freq='epoch',
|
||||
options=None,
|
||||
**kwargs):
|
||||
super(ModelCheckpoint, self).__init__()
|
||||
self._supports_tf_logs = True
|
||||
@ -1140,6 +1146,20 @@ class ModelCheckpoint(Callback):
|
||||
self._batches_seen_since_last_saving = 0
|
||||
self._last_batch_seen = 0
|
||||
|
||||
if save_weights_only:
|
||||
if options is None or isinstance(
|
||||
options, checkpoint_options_lib.CheckpointOptions):
|
||||
self._options = options or checkpoint_options_lib.CheckpointOptions()
|
||||
else:
|
||||
raise TypeError('If save_weights_only is True, then `options` must be'
|
||||
'either None or a tf.train.CheckpointOptions')
|
||||
else:
|
||||
if options is None or isinstance(options, save_options_lib.SaveOptions):
|
||||
self._options = options or save_options_lib.SaveOptions()
|
||||
else:
|
||||
raise TypeError('If save_weights_only is False, then `options` must be'
|
||||
'either None or a tf.saved_model.SaveOptions')
|
||||
|
||||
# Deprecated field `load_weights_on_restart` is for loading the checkpoint
|
||||
# file from `filepath` at the start of `model.fit()`
|
||||
# TODO(rchao): Remove the arg during next breaking release.
|
||||
@ -1269,9 +1289,10 @@ class ModelCheckpoint(Callback):
|
||||
self.best, current, filepath))
|
||||
self.best = current
|
||||
if self.save_weights_only:
|
||||
self.model.save_weights(filepath, overwrite=True)
|
||||
self.model.save_weights(
|
||||
filepath, overwrite=True, options=self._options)
|
||||
else:
|
||||
self.model.save(filepath, overwrite=True)
|
||||
self.model.save(filepath, overwrite=True, options=self._options)
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
print('\nEpoch %05d: %s did not improve from %0.5f' %
|
||||
@ -1280,9 +1301,10 @@ class ModelCheckpoint(Callback):
|
||||
if self.verbose > 0:
|
||||
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
|
||||
if self.save_weights_only:
|
||||
self.model.save_weights(filepath, overwrite=True)
|
||||
self.model.save_weights(
|
||||
filepath, overwrite=True, options=self._options)
|
||||
else:
|
||||
self.model.save(filepath, overwrite=True)
|
||||
self.model.save(filepath, overwrite=True, options=self._options)
|
||||
|
||||
self._maybe_remove_file()
|
||||
except IOError as e:
|
||||
|
@ -49,9 +49,11 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import save_options as save_options_lib
|
||||
from tensorflow.python.summary import summary_iterator
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
|
||||
|
||||
try:
|
||||
import h5py # pylint:disable=g-import-not-at-top
|
||||
@ -666,6 +668,38 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
mode=mode,
|
||||
save_freq=3)
|
||||
|
||||
# Case 9: `ModelCheckpoint` with valid and invalid `options` argument.
|
||||
with self.assertRaisesRegexp(TypeError, 'tf.train.CheckpointOptions'):
|
||||
keras.callbacks.ModelCheckpoint(
|
||||
filepath,
|
||||
monitor=monitor,
|
||||
save_best_only=save_best_only,
|
||||
save_weights_only=True,
|
||||
mode=mode,
|
||||
options=save_options_lib.SaveOptions())
|
||||
with self.assertRaisesRegexp(TypeError, 'tf.saved_model.SaveOptions'):
|
||||
keras.callbacks.ModelCheckpoint(
|
||||
filepath,
|
||||
monitor=monitor,
|
||||
save_best_only=save_best_only,
|
||||
save_weights_only=False,
|
||||
mode=mode,
|
||||
options=checkpoint_options_lib.CheckpointOptions())
|
||||
keras.callbacks.ModelCheckpoint(
|
||||
filepath,
|
||||
monitor=monitor,
|
||||
save_best_only=save_best_only,
|
||||
save_weights_only=True,
|
||||
mode=mode,
|
||||
options=checkpoint_options_lib.CheckpointOptions())
|
||||
keras.callbacks.ModelCheckpoint(
|
||||
filepath,
|
||||
monitor=monitor,
|
||||
save_best_only=save_best_only,
|
||||
save_weights_only=False,
|
||||
mode=mode,
|
||||
options=save_options_lib.SaveOptions())
|
||||
|
||||
def _get_dummy_resource_for_model_checkpoint_testing(self):
|
||||
|
||||
def get_input_datasets():
|
||||
|
@ -1979,7 +1979,11 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
save.save_model(self, filepath, overwrite, include_optimizer, save_format,
|
||||
signatures, options)
|
||||
|
||||
def save_weights(self, filepath, overwrite=True, save_format=None):
|
||||
def save_weights(self,
|
||||
filepath,
|
||||
overwrite=True,
|
||||
save_format=None,
|
||||
options=None):
|
||||
"""Saves all layer weights.
|
||||
|
||||
Either saves in HDF5 or in TensorFlow format based on the `save_format`
|
||||
@ -2032,6 +2036,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
|
||||
'.keras' will default to HDF5 if `save_format` is `None`. Otherwise
|
||||
`None` defaults to 'tf'.
|
||||
options: Optional `tf.train.CheckpointOptions` object that specifies
|
||||
options for saving weights.
|
||||
|
||||
Raises:
|
||||
ImportError: If h5py is not available when attempting to save in HDF5
|
||||
@ -2093,7 +2099,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
'the TensorFlow format the optimizer\'s state will not be '
|
||||
'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
|
||||
% (optimizer,))
|
||||
self._trackable_saver.save(filepath, session=session)
|
||||
self._trackable_saver.save(filepath, session=session, options=options)
|
||||
# Record this checkpoint so it's visible from tf.train.latest_checkpoint.
|
||||
checkpoint_management.update_checkpoint_state_internal(
|
||||
save_dir=os.path.dirname(filepath),
|
||||
|
@ -302,7 +302,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -320,7 +320,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -5,7 +5,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'options\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_batch_begin"
|
||||
|
@ -303,7 +303,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -303,7 +303,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -302,7 +302,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -320,7 +320,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -302,7 +302,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -320,7 +320,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -5,7 +5,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'options\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_batch_begin"
|
||||
|
@ -303,7 +303,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -303,7 +303,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -302,7 +302,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
@ -320,7 +320,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
|
Loading…
Reference in New Issue
Block a user