(CL 2/2) Allow model_to_estimator
to save object-based checkpoints.
As requested in b/129562822, model_to_estimator should be able to save checkpoints that are compatible with the original Keras model. This allows a user to train their model at scale, then load the trained checkpoint back to the Keras model. This change also allows estimators to warm start from object-based checkpoints. (b/129870586) Part one: cl/247107217 PiperOrigin-RevId: 250519086
This commit is contained in:
parent
c7a1366349
commit
015570dad8
@ -26,13 +26,14 @@ from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
# LINT.IfChange
|
||||
@keras_export('keras.estimator.model_to_estimator')
|
||||
@keras_export(v1=['keras.estimator.model_to_estimator'])
|
||||
def model_to_estimator(
|
||||
keras_model=None,
|
||||
keras_model_path=None,
|
||||
custom_objects=None,
|
||||
model_dir=None,
|
||||
config=None):
|
||||
config=None,
|
||||
checkpoint_format='saver'):
|
||||
"""Constructs an `Estimator` instance from given keras model.
|
||||
|
||||
For usage example, please see:
|
||||
@ -49,6 +50,15 @@ def model_to_estimator(
|
||||
model_dir: Directory to save `Estimator` model parameters, graph, summary
|
||||
files for TensorBoard, etc.
|
||||
config: `RunConfig` to config `Estimator`.
|
||||
checkpoint_format: Sets the format of the checkpoint saved by the estimator
|
||||
when training. May be `saver` or `checkpoint`, depending on whether to
|
||||
save checkpoints from `tf.train.Saver` or `tf.train.Checkpoint`. This
|
||||
argument currently defaults to `saver`. When 2.0 is released, the default
|
||||
will be `checkpoint`. Estimators use name-based `tf.train.Saver`
|
||||
checkpoints, while Keras models use object-based checkpoints from
|
||||
`tf.train.Checkpoint`. Currently, saving object-based checkpoints from
|
||||
`model_to_estimator` is only supported by Functional and Sequential
|
||||
models.
|
||||
|
||||
Returns:
|
||||
An Estimator from given keras model.
|
||||
@ -58,6 +68,7 @@ def model_to_estimator(
|
||||
ValueError: if both keras_model and keras_model_path was given.
|
||||
ValueError: if the keras_model_path is a GCS URI.
|
||||
ValueError: if keras_model has not been compiled.
|
||||
ValueError: if an invalid checkpoint_format was given.
|
||||
"""
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top
|
||||
@ -65,13 +76,71 @@ def model_to_estimator(
|
||||
raise NotImplementedError(
|
||||
'tf.keras.estimator.model_to_estimator function not available in your '
|
||||
'installation.')
|
||||
return keras_lib.model_to_estimator(
|
||||
return keras_lib.model_to_estimator( # pylint:disable=unexpected-keyword-arg
|
||||
keras_model=keras_model,
|
||||
keras_model_path=keras_model_path,
|
||||
custom_objects=custom_objects,
|
||||
model_dir=model_dir,
|
||||
config=config)
|
||||
config=config,
|
||||
checkpoint_format=checkpoint_format)
|
||||
|
||||
|
||||
@keras_export('keras.estimator.model_to_estimator', v1=[])
|
||||
def model_to_estimator_v2(
|
||||
keras_model=None,
|
||||
keras_model_path=None,
|
||||
custom_objects=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
checkpoint_format='checkpoint'):
|
||||
"""Constructs an `Estimator` instance from given keras model.
|
||||
|
||||
For usage example, please see:
|
||||
[Creating estimators from Keras
|
||||
Models](https://tensorflow.org/guide/estimators#model_to_estimator).
|
||||
|
||||
Args:
|
||||
keras_model: A compiled Keras model object. This argument is mutually
|
||||
exclusive with `keras_model_path`.
|
||||
keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
|
||||
format, which can be generated with the `save()` method of a Keras model.
|
||||
This argument is mutually exclusive with `keras_model`.
|
||||
custom_objects: Dictionary for custom objects.
|
||||
model_dir: Directory to save `Estimator` model parameters, graph, summary
|
||||
files for TensorBoard, etc.
|
||||
config: `RunConfig` to config `Estimator`.
|
||||
checkpoint_format: Sets the format of the checkpoint saved by the estimator
|
||||
when training. May be `saver` or `checkpoint`, depending on whether to
|
||||
save checkpoints from `tf.compat.v1.train.Saver` or `tf.train.Checkpoint`.
|
||||
The default is `checkpoint`. Estimators use name-based `tf.train.Saver`
|
||||
checkpoints, while Keras models use object-based checkpoints from
|
||||
`tf.train.Checkpoint`. Currently, saving object-based checkpoints from
|
||||
`model_to_estimator` is only supported by Functional and Sequential
|
||||
models.
|
||||
|
||||
Returns:
|
||||
An Estimator from given keras model.
|
||||
|
||||
Raises:
|
||||
ValueError: if neither keras_model nor keras_model_path was given.
|
||||
ValueError: if both keras_model and keras_model_path was given.
|
||||
ValueError: if the keras_model_path is a GCS URI.
|
||||
ValueError: if keras_model has not been compiled.
|
||||
ValueError: if an invalid checkpoint_format was given.
|
||||
"""
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
'tf.keras.estimator.model_to_estimator function not available in your '
|
||||
'installation.')
|
||||
return keras_lib.model_to_estimator( # pylint:disable=unexpected-keyword-arg
|
||||
keras_model=keras_model,
|
||||
keras_model_path=keras_model_path,
|
||||
custom_objects=custom_objects,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
checkpoint_format=checkpoint_format)
|
||||
# LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py)
|
||||
|
||||
|
||||
|
@ -2,6 +2,6 @@ path: "tensorflow.keras.estimator"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "model_to_estimator"
|
||||
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'saver\'], "
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,6 @@ path: "tensorflow.keras.estimator"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "model_to_estimator"
|
||||
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'checkpoint\'], "
|
||||
}
|
||||
}
|
||||
|
@ -973,6 +973,12 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
||||
deprecate_partition_strategy_comment,
|
||||
"tf.nn.sampled_softmax_loss":
|
||||
deprecate_partition_strategy_comment,
|
||||
"tf.keras.estimator.model_to_estimator":
|
||||
(ast_edits.WARNING,
|
||||
"Estimators from <function name> will save object-based "
|
||||
"checkpoints (format used by `keras_model.save_weights` and "
|
||||
"`keras_model.load_weights`) by default in 2.0. To continue "
|
||||
"saving name-based checkpoints, set `checkpoint_format='saver'`."),
|
||||
"tf.keras.initializers.Zeros":
|
||||
initializers_no_dtype_comment,
|
||||
"tf.keras.initializers.zeros":
|
||||
|
@ -2087,6 +2087,11 @@ def _log_prob(self, x):
|
||||
self.assertEqual(result_a[3], expected_text_a)
|
||||
self.assertEqual(result_b[3], expected_text_b)
|
||||
|
||||
def test_model_to_estimator_checkpoint_warning(self):
|
||||
text = "tf.keras.estimator.model_to_estimator(model)"
|
||||
_, report, _, _ = self._upgrade(text)
|
||||
expected_info = "will save object-based checkpoints"
|
||||
self.assertIn(expected_info, report)
|
||||
|
||||
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user