(CL 2/2) Allow model_to_estimator to save object-based checkpoints.

As requested in b/129562822, model_to_estimator should be able to save checkpoints that are compatible with the original Keras model. This allows a user to train their model at scale, then load the trained checkpoint back to the Keras model.

This change also allows estimators to warm start from object-based checkpoints. (b/129870586)

Part one: cl/247107217

PiperOrigin-RevId: 250519086
This commit is contained in:
Katherine Wu 2019-05-29 10:33:20 -07:00 committed by TensorFlower Gardener
parent c7a1366349
commit 015570dad8
5 changed files with 86 additions and 6 deletions

View File

@ -26,13 +26,14 @@ from tensorflow.python.util.tf_export import keras_export
# LINT.IfChange
@keras_export('keras.estimator.model_to_estimator')
@keras_export(v1=['keras.estimator.model_to_estimator'])
def model_to_estimator(
keras_model=None,
keras_model_path=None,
custom_objects=None,
model_dir=None,
config=None):
config=None,
checkpoint_format='saver'):
"""Constructs an `Estimator` instance from given keras model.
For usage example, please see:
@ -49,6 +50,15 @@ def model_to_estimator(
model_dir: Directory to save `Estimator` model parameters, graph, summary
files for TensorBoard, etc.
config: `RunConfig` to config `Estimator`.
checkpoint_format: Sets the format of the checkpoint saved by the estimator
when training. May be `saver` or `checkpoint`, depending on whether to
save checkpoints from `tf.train.Saver` or `tf.train.Checkpoint`. This
argument currently defaults to `saver`. When 2.0 is released, the default
will be `checkpoint`. Estimators use name-based `tf.train.Saver`
checkpoints, while Keras models use object-based checkpoints from
`tf.train.Checkpoint`. Currently, saving object-based checkpoints from
`model_to_estimator` is only supported by Functional and Sequential
models.
Returns:
An Estimator from given keras model.
@ -58,6 +68,7 @@ def model_to_estimator(
ValueError: if both keras_model and keras_model_path was given.
ValueError: if the keras_model_path is a GCS URI.
ValueError: if keras_model has not been compiled.
ValueError: if an invalid checkpoint_format was given.
"""
try:
from tensorflow_estimator.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top
@ -65,13 +76,71 @@ def model_to_estimator(
raise NotImplementedError(
'tf.keras.estimator.model_to_estimator function not available in your '
'installation.')
return keras_lib.model_to_estimator(
return keras_lib.model_to_estimator( # pylint:disable=unexpected-keyword-arg
keras_model=keras_model,
keras_model_path=keras_model_path,
custom_objects=custom_objects,
model_dir=model_dir,
config=config)
config=config,
checkpoint_format=checkpoint_format)
@keras_export('keras.estimator.model_to_estimator', v1=[])
def model_to_estimator_v2(
keras_model=None,
keras_model_path=None,
custom_objects=None,
model_dir=None,
config=None,
checkpoint_format='checkpoint'):
"""Constructs an `Estimator` instance from given keras model.
For usage example, please see:
[Creating estimators from Keras
Models](https://tensorflow.org/guide/estimators#model_to_estimator).
Args:
keras_model: A compiled Keras model object. This argument is mutually
exclusive with `keras_model_path`.
keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
format, which can be generated with the `save()` method of a Keras model.
This argument is mutually exclusive with `keras_model`.
custom_objects: Dictionary for custom objects.
model_dir: Directory to save `Estimator` model parameters, graph, summary
files for TensorBoard, etc.
config: `RunConfig` to config `Estimator`.
checkpoint_format: Sets the format of the checkpoint saved by the estimator
when training. May be `saver` or `checkpoint`, depending on whether to
save checkpoints from `tf.compat.v1.train.Saver` or `tf.train.Checkpoint`.
The default is `checkpoint`. Estimators use name-based `tf.train.Saver`
checkpoints, while Keras models use object-based checkpoints from
`tf.train.Checkpoint`. Currently, saving object-based checkpoints from
`model_to_estimator` is only supported by Functional and Sequential
models.
Returns:
An Estimator from given keras model.
Raises:
ValueError: if neither keras_model nor keras_model_path was given.
ValueError: if both keras_model and keras_model_path was given.
ValueError: if the keras_model_path is a GCS URI.
ValueError: if keras_model has not been compiled.
ValueError: if an invalid checkpoint_format was given.
"""
try:
from tensorflow_estimator.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top
except ImportError:
raise NotImplementedError(
'tf.keras.estimator.model_to_estimator function not available in your '
'installation.')
return keras_lib.model_to_estimator( # pylint:disable=unexpected-keyword-arg
keras_model=keras_model,
keras_model_path=keras_model_path,
custom_objects=custom_objects,
model_dir=model_dir,
config=config,
checkpoint_format=checkpoint_format)
# LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py)

View File

@ -2,6 +2,6 @@ path: "tensorflow.keras.estimator"
tf_module {
member_method {
name: "model_to_estimator"
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'saver\'], "
}
}

View File

@ -2,6 +2,6 @@ path: "tensorflow.keras.estimator"
tf_module {
member_method {
name: "model_to_estimator"
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'checkpoint\'], "
}
}

View File

@ -973,6 +973,12 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
deprecate_partition_strategy_comment,
"tf.nn.sampled_softmax_loss":
deprecate_partition_strategy_comment,
"tf.keras.estimator.model_to_estimator":
(ast_edits.WARNING,
"Estimators from <function name> will save object-based "
"checkpoints (format used by `keras_model.save_weights` and "
"`keras_model.load_weights`) by default in 2.0. To continue "
"saving name-based checkpoints, set `checkpoint_format='saver'`."),
"tf.keras.initializers.Zeros":
initializers_no_dtype_comment,
"tf.keras.initializers.zeros":

View File

@ -2087,6 +2087,11 @@ def _log_prob(self, x):
self.assertEqual(result_a[3], expected_text_a)
self.assertEqual(result_b[3], expected_text_b)
def test_model_to_estimator_checkpoint_warning(self):
text = "tf.keras.estimator.model_to_estimator(model)"
_, report, _, _ = self._upgrade(text)
expected_info = "will save object-based checkpoints"
self.assertIn(expected_info, report)
class TestUpgradeFiles(test_util.TensorFlowTestCase):