diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/python/keras/estimator/__init__.py index f0e04c066d6..ccb2b74351a 100644 --- a/tensorflow/python/keras/estimator/__init__.py +++ b/tensorflow/python/keras/estimator/__init__.py @@ -26,13 +26,14 @@ from tensorflow.python.util.tf_export import keras_export # LINT.IfChange -@keras_export('keras.estimator.model_to_estimator') +@keras_export(v1=['keras.estimator.model_to_estimator']) def model_to_estimator( keras_model=None, keras_model_path=None, custom_objects=None, model_dir=None, - config=None): + config=None, + checkpoint_format='saver'): """Constructs an `Estimator` instance from given keras model. For usage example, please see: @@ -49,6 +50,15 @@ def model_to_estimator( model_dir: Directory to save `Estimator` model parameters, graph, summary files for TensorBoard, etc. config: `RunConfig` to config `Estimator`. + checkpoint_format: Sets the format of the checkpoint saved by the estimator + when training. May be `saver` or `checkpoint`, depending on whether to + save checkpoints from `tf.train.Saver` or `tf.train.Checkpoint`. This + argument currently defaults to `saver`. When 2.0 is released, the default + will be `checkpoint`. Estimators use name-based `tf.train.Saver` + checkpoints, while Keras models use object-based checkpoints from + `tf.train.Checkpoint`. Currently, saving object-based checkpoints from + `model_to_estimator` is only supported by Functional and Sequential + models. Returns: An Estimator from given keras model. @@ -58,6 +68,7 @@ def model_to_estimator( ValueError: if both keras_model and keras_model_path was given. ValueError: if the keras_model_path is a GCS URI. ValueError: if keras_model has not been compiled. + ValueError: if an invalid checkpoint_format was given. """ try: from tensorflow_estimator.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top @@ -65,13 +76,71 @@ def model_to_estimator( raise NotImplementedError( 'tf.keras.estimator.model_to_estimator function not available in your ' 'installation.') - return keras_lib.model_to_estimator( + return keras_lib.model_to_estimator( # pylint:disable=unexpected-keyword-arg keras_model=keras_model, keras_model_path=keras_model_path, custom_objects=custom_objects, model_dir=model_dir, - config=config) + config=config, + checkpoint_format=checkpoint_format) + +@keras_export('keras.estimator.model_to_estimator', v1=[]) +def model_to_estimator_v2( + keras_model=None, + keras_model_path=None, + custom_objects=None, + model_dir=None, + config=None, + checkpoint_format='checkpoint'): + """Constructs an `Estimator` instance from given keras model. + + For usage example, please see: + [Creating estimators from Keras + Models](https://tensorflow.org/guide/estimators#model_to_estimator). + + Args: + keras_model: A compiled Keras model object. This argument is mutually + exclusive with `keras_model_path`. + keras_model_path: Path to a compiled Keras model saved on disk, in HDF5 + format, which can be generated with the `save()` method of a Keras model. + This argument is mutually exclusive with `keras_model`. + custom_objects: Dictionary for custom objects. + model_dir: Directory to save `Estimator` model parameters, graph, summary + files for TensorBoard, etc. + config: `RunConfig` to config `Estimator`. + checkpoint_format: Sets the format of the checkpoint saved by the estimator + when training. May be `saver` or `checkpoint`, depending on whether to + save checkpoints from `tf.compat.v1.train.Saver` or `tf.train.Checkpoint`. + The default is `checkpoint`. Estimators use name-based `tf.train.Saver` + checkpoints, while Keras models use object-based checkpoints from + `tf.train.Checkpoint`. Currently, saving object-based checkpoints from + `model_to_estimator` is only supported by Functional and Sequential + models. + + Returns: + An Estimator from given keras model. + + Raises: + ValueError: if neither keras_model nor keras_model_path was given. + ValueError: if both keras_model and keras_model_path was given. + ValueError: if the keras_model_path is a GCS URI. + ValueError: if keras_model has not been compiled. + ValueError: if an invalid checkpoint_format was given. + """ + try: + from tensorflow_estimator.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top + except ImportError: + raise NotImplementedError( + 'tf.keras.estimator.model_to_estimator function not available in your ' + 'installation.') + return keras_lib.model_to_estimator( # pylint:disable=unexpected-keyword-arg + keras_model=keras_model, + keras_model_path=keras_model_path, + custom_objects=custom_objects, + model_dir=model_dir, + config=config, + checkpoint_format=checkpoint_format) # LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.estimator.pbtxt index 7a3fb39f774..d0dca9a5a31 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.estimator.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.estimator.pbtxt @@ -2,6 +2,6 @@ path: "tensorflow.keras.estimator" tf_module { member_method { name: "model_to_estimator" - argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'saver\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt index 7a3fb39f774..81fcfd87cda 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt @@ -2,6 +2,6 @@ path: "tensorflow.keras.estimator" tf_module { member_method { name: "model_to_estimator" - argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\', \'checkpoint_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'checkpoint\'], " } } diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 24c55d2113f..c834d84e70b 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -973,6 +973,12 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec): deprecate_partition_strategy_comment, "tf.nn.sampled_softmax_loss": deprecate_partition_strategy_comment, + "tf.keras.estimator.model_to_estimator": + (ast_edits.WARNING, + "Estimators from will save object-based " + "checkpoints (format used by `keras_model.save_weights` and " + "`keras_model.load_weights`) by default in 2.0. To continue " + "saving name-based checkpoints, set `checkpoint_format='saver'`."), "tf.keras.initializers.Zeros": initializers_no_dtype_comment, "tf.keras.initializers.zeros": diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index ab292ccf712..e12bc54cf8d 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -2087,6 +2087,11 @@ def _log_prob(self, x): self.assertEqual(result_a[3], expected_text_a) self.assertEqual(result_b[3], expected_text_b) + def test_model_to_estimator_checkpoint_warning(self): + text = "tf.keras.estimator.model_to_estimator(model)" + _, report, _, _ = self._upgrade(text) + expected_info = "will save object-based checkpoints" + self.assertIn(expected_info, report) class TestUpgradeFiles(test_util.TensorFlowTestCase):