diff --git a/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py index 6b94fda0f34..f61631a3b62 100644 --- a/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py +++ b/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py @@ -40,6 +40,11 @@ flags.DEFINE_string( flags.DEFINE_integer( 'epochs', 10, 'Number of epochs to train.') +flags.DEFINE_bool( + 'use_keras_save_api', False, + 'Uses tf.keras.models.save_model() on the feature extractor ' + 'instead of tf.saved_model.save() on a manually wrapped version. ' + 'With this, the exported model as no hparams.') flags.DEFINE_bool( 'fast_test_mode', False, 'Shortcut training for running in unit tests.') @@ -180,11 +185,19 @@ def main(argv): # Save the feature extractor to a framework-agnostic SavedModel for reuse. # Note that the feature_extractor object has not been compiled or fitted, # so it does not contain an optimizer and related state. - exportable = wrap_keras_model_for_export(feature_extractor, - (None,) + mnist_util.INPUT_SHAPE, - set_feature_extractor_hparams, - default_hparams) - tf.saved_model.save(exportable, FLAGS.export_dir) + if FLAGS.use_keras_save_api: + # Use Keras' built-in way of creating reusable SavedModels. + # This has no support for adjustable hparams at this time (July 2019). + # (We could also call tf.saved_model.save(feature_extractor, ...), + # point is we're passing a Keras model, not a plain Checkpoint.) + tf.keras.models.save_model(feature_extractor, FLAGS.export_dir) + else: + # Assemble a reusable SavedModel manually, with adjustable hparams. + exportable = wrap_keras_model_for_export(feature_extractor, + (None,) + mnist_util.INPUT_SHAPE, + set_feature_extractor_hparams, + default_hparams) + tf.saved_model.save(exportable, FLAGS.export_dir) if __name__ == '__main__': diff --git a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py index 5c198d864c7..232a5b5e1ba 100644 --- a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py +++ b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py @@ -74,16 +74,19 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase): combinations=( combinations.combine( # Test all combinations with tf.saved_model.save(). - use_keras_save_api=False, + # Test all combinations using tf.keras.models.save_model() + # for both the reusable and the final full model. + use_keras_save_api=True, named_strategy=list(ds_utils.named_strategies.values()), retrain_flag_value=["true", "false"], regularization_loss_multiplier=[None, 2], # Test for b/134528831. ) + combinations.combine( - # Test few critcial combinations with tf.keras.models.save_model() - # which is merely a thin wrapper (as of June 2019). - use_keras_save_api=True, + # Test few critcial combinations with raw tf.saved_model.save(), + # including export of a reusable SavedModel that gets assembled + # manually, including support for adjustable hparams. + use_keras_save_api=False, named_strategy=None, - retrain_flag_value="true", + retrain_flag_value=["true", "false"], regularization_loss_multiplier=[None, 2], # Test for b/134528831. )), test_combinations=[combinations.NamedGPUCombination()]) @@ -102,14 +105,14 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase): self.assertCommandSucceeded( "export_mnist_cnn", fast_test_mode=fast_test_mode, - export_dir=feature_extrator_dir) + export_dir=feature_extrator_dir, + use_keras_save_api=use_keras_save_api) use_kwargs = dict(fast_test_mode=fast_test_mode, input_saved_model_dir=feature_extrator_dir, retrain=retrain_flag_value, + output_saved_model_dir=full_model_dir, use_keras_save_api=use_keras_save_api) - if full_model_dir is not None: - use_kwargs["output_saved_model_dir"] = full_model_dir if named_strategy: use_kwargs["strategy"] = str(named_strategy) if regularization_loss_multiplier is not None: @@ -117,11 +120,10 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase): "regularization_loss_multiplier"] = regularization_loss_multiplier self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs) - if full_model_dir is not None: - self.assertCommandSucceeded( - "deploy_mnist_cnn", - fast_test_mode=fast_test_mode, - saved_model_dir=full_model_dir) + self.assertCommandSucceeded( + "deploy_mnist_cnn", + fast_test_mode=fast_test_mode, + saved_model_dir=full_model_dir) if __name__ == "__main__": diff --git a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py index 24d1be4aa50..ae45a02a59b 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py +++ b/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py @@ -47,7 +47,8 @@ flags.DEFINE_bool( 'If set, the imported SavedModel is trained further.') flags.DEFINE_float( 'dropout_rate', None, - 'If set, dropout rate passed to the SavedModel.') + 'If set, dropout rate passed to the SavedModel. ' + 'Requires a SavedModel with support for adjustable hyperparameters.') flags.DEFINE_float( 'regularization_loss_multiplier', None, 'If set, multiplier for the regularization losses in the SavedModel.')