Prefer Keras saving API in the MNIST SavedModel example over manual wrapping for export as a reusable SavedModel.

PiperOrigin-RevId: 260925200
A. Unique TensorFlower 2019-07-31 07:48:46 -07:00 committed by TensorFlower Gardener
parent 0fd17699e8
commit f068f55bee
3 changed files with 35 additions and 19 deletions


@@ -40,6 +40,11 @@ flags.DEFINE_string(
flags.DEFINE_integer(
'epochs', 10,
'Number of epochs to train.')
flags.DEFINE_bool(
'use_keras_save_api', False,
'Uses tf.keras.models.save_model() on the feature extractor '
'instead of tf.saved_model.save() on a manually wrapped version. '
'With this, the exported model has no hparams.')
flags.DEFINE_bool(
'fast_test_mode', False,
'Shortcut training for running in unit tests.')
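
For context, a minimal, self-contained sketch of how an absl flag such as the new use_keras_save_api is defined and read; the main() body here is illustrative and not taken from the example script.

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_bool(
    'use_keras_save_api', False,
    'Uses tf.keras.models.save_model() on the feature extractor '
    'instead of tf.saved_model.save() on a manually wrapped version.')


def main(argv):
  del argv  # Unused.
  # After app.run() has parsed the command line, the flag is available here.
  print('use_keras_save_api =', FLAGS.use_keras_save_api)


if __name__ == '__main__':
  app.run(main)
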
@@ -180,11 +185,19 @@ def main(argv):
# Save the feature extractor to a framework-agnostic SavedModel for reuse.
# Note that the feature_extractor object has not been compiled or fitted,
# so it does not contain an optimizer and related state.
exportable = wrap_keras_model_for_export(feature_extractor,
(None,) + mnist_util.INPUT_SHAPE,
set_feature_extractor_hparams,
default_hparams)
tf.saved_model.save(exportable, FLAGS.export_dir)
if FLAGS.use_keras_save_api:
# Use Keras' built-in way of creating reusable SavedModels.
# This has no support for adjustable hparams at this time (July 2019).
# (We could also call tf.saved_model.save(feature_extractor, ...),
# point is we're passing a Keras model, not a plain Checkpoint.)
tf.keras.models.save_model(feature_extractor, FLAGS.export_dir)
else:
# Assemble a reusable SavedModel manually, with adjustable hparams.
exportable = wrap_keras_model_for_export(feature_extractor,
(None,) + mnist_util.INPUT_SHAPE,
set_feature_extractor_hparams,
default_hparams)
tf.saved_model.save(exportable, FLAGS.export_dir)
if __name__ == '__main__':
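
The hunk above switches between two ways of producing the reusable SavedModel. A simplified sketch of both paths follows; the tiny model, the ExportModule wrapper, and the /tmp paths are stand-ins (the example's wrap_keras_model_for_export additionally wires up adjustable hparams and regularization losses), not code from this commit.

import tensorflow as tf

# Illustrative feature extractor; the real example builds an MNIST CNN.
feature_extractor = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.GlobalAveragePooling2D(),
])

# Path 1: Keras' built-in saving; the exported call signature is fixed and
# there is no hook for adjustable hparams.
tf.keras.models.save_model(feature_extractor, '/tmp/mnist_keras_export')

# Path 2: manual wrapping; the exported __call__ is an explicit tf.function,
# which is where adjustable hparams would be threaded through.
class ExportModule(tf.Module):

  def __init__(self, model):
    super().__init__()
    self.model = model

  @tf.function(input_signature=[tf.TensorSpec([None, 28, 28, 1], tf.float32)])
  def __call__(self, images):
    return self.model(images)

tf.saved_model.save(ExportModule(feature_extractor), '/tmp/mnist_manual_export')
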


@@ -74,16 +74,19 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
combinations=(
combinations.combine(
# Test all combinations with tf.saved_model.save().
use_keras_save_api=False,
# Test all combinations using tf.keras.models.save_model()
# for both the reusable and the final full model.
use_keras_save_api=True,
named_strategy=list(ds_utils.named_strategies.values()),
retrain_flag_value=["true", "false"],
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
) + combinations.combine(
# Test few critical combinations with tf.keras.models.save_model()
# which is merely a thin wrapper (as of June 2019).
use_keras_save_api=True,
# Test few critical combinations with raw tf.saved_model.save(),
# including export of a reusable SavedModel that gets assembled
# manually, including support for adjustable hparams.
use_keras_save_api=False,
named_strategy=None,
retrain_flag_value="true",
retrain_flag_value=["true", "false"],
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
)),
test_combinations=[combinations.NamedGPUCombination()])
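
The decorator above expands the cross product of parameter values into individual test cases. Below is an illustrative stand-in using absl.testing.parameterized; combine() here is a local helper, not the TF-internal combinations API used by the real test, and the parameter values are copied from the hunk only as examples.

import itertools

from absl.testing import absltest
from absl.testing import parameterized


def combine(**kwargs):
  """Yields one kwargs dict per element of the cross product of the values."""
  keys = sorted(kwargs)
  for values in itertools.product(*(kwargs[k] for k in keys)):
    yield dict(zip(keys, values))


class SavedModelTestSketch(parameterized.TestCase):

  @parameterized.parameters(*combine(
      use_keras_save_api=[True, False],
      retrain_flag_value=['true', 'false'],
      regularization_loss_multiplier=[None, 2]))
  def test_flag_combination(self, use_keras_save_api, retrain_flag_value,
                            regularization_loss_multiplier):
    # The real test shells out to export_mnist_cnn / use_mnist_cnn with these
    # values; here we only check that the expansion produced sensible cases.
    del use_keras_save_api, regularization_loss_multiplier  # Unused here.
    self.assertIn(retrain_flag_value, ('true', 'false'))


if __name__ == '__main__':
  absltest.main()
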
@@ -102,14 +105,14 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
self.assertCommandSucceeded(
"export_mnist_cnn",
fast_test_mode=fast_test_mode,
export_dir=feature_extrator_dir)
export_dir=feature_extrator_dir,
use_keras_save_api=use_keras_save_api)
use_kwargs = dict(fast_test_mode=fast_test_mode,
input_saved_model_dir=feature_extrator_dir,
retrain=retrain_flag_value,
output_saved_model_dir=full_model_dir,
use_keras_save_api=use_keras_save_api)
if full_model_dir is not None:
use_kwargs["output_saved_model_dir"] = full_model_dir
if named_strategy:
use_kwargs["strategy"] = str(named_strategy)
if regularization_loss_multiplier is not None:
@@ -117,11 +120,10 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
"regularization_loss_multiplier"] = regularization_loss_multiplier
self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs)
if full_model_dir is not None:
self.assertCommandSucceeded(
"deploy_mnist_cnn",
fast_test_mode=fast_test_mode,
saved_model_dir=full_model_dir)
self.assertCommandSucceeded(
"deploy_mnist_cnn",
fast_test_mode=fast_test_mode,
saved_model_dir=full_model_dir)
if __name__ == "__main__":
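
The assertCommandSucceeded calls above run the example scripts with the given flags. A hedged sketch of that pattern follows; run_script, the script file names, and the /tmp paths are illustrative placeholders, not the test's actual helper.

import subprocess
import sys


def run_script(script_path, **flag_values):
  """Runs a Python script, passing each kwarg as an absl-style --flag=value."""
  args = [sys.executable, script_path]
  for name, value in sorted(flag_values.items()):
    args.append('--{}={}'.format(name, value))
  subprocess.run(args, check=True)  # Raises CalledProcessError on failure.


# Usage mirroring the test: export first, then consume the SavedModel.
run_script('export_mnist_cnn.py',
           fast_test_mode=True,
           export_dir='/tmp/feature_extractor',
           use_keras_save_api=True)
run_script('use_mnist_cnn.py',
           fast_test_mode=True,
           input_saved_model_dir='/tmp/feature_extractor',
           retrain='true',
           output_saved_model_dir='/tmp/full_model')
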


@@ -47,7 +47,8 @@ flags.DEFINE_bool(
'If set, the imported SavedModel is trained further.')
flags.DEFINE_float(
'dropout_rate', None,
'If set, dropout rate passed to the SavedModel.')
'If set, dropout rate passed to the SavedModel. '
'Requires a SavedModel with support for adjustable hyperparameters.')
flags.DEFINE_float(
'regularization_loss_multiplier', None,
'If set, multiplier for the regularization losses in the SavedModel.')
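
For context, a minimal sketch of consuming the exported SavedModel; the path, the input shape, and the set_dropout_rate hook are assumptions, and whether any hparam hook exists at all depends on the export path (only the manually wrapped export wires up adjustable hparams).

import tensorflow as tf

loaded = tf.saved_model.load('/tmp/feature_extractor')  # Placeholder path.
images = tf.zeros([8, 28, 28, 1], tf.float32)  # Assumed MNIST-shaped input.

# A plain forward pass, assuming the exported __call__ accepts a batch of
# images; this works for either export path.
features = loaded(images)

# Adjusting an hparam such as dropout_rate only works if the exporter attached
# a hook for it; a SavedModel written via use_keras_save_api does not have one.
if hasattr(loaded, 'set_dropout_rate'):  # Hypothetical hook name.
  loaded.set_dropout_rate(0.5)
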