Prefer Keras saving API in the MNIST SavedModel example over manual wrapping
for export as reusable SavedModel. PiperOrigin-RevId: 260925200
This commit is contained in:
parent
0fd17699e8
commit
f068f55bee
@ -40,6 +40,11 @@ flags.DEFINE_string(
|
||||
flags.DEFINE_integer(
|
||||
'epochs', 10,
|
||||
'Number of epochs to train.')
|
||||
flags.DEFINE_bool(
|
||||
'use_keras_save_api', False,
|
||||
'Uses tf.keras.models.save_model() on the feature extractor '
|
||||
'instead of tf.saved_model.save() on a manually wrapped version. '
|
||||
'With this, the exported model as no hparams.')
|
||||
flags.DEFINE_bool(
|
||||
'fast_test_mode', False,
|
||||
'Shortcut training for running in unit tests.')
|
||||
@ -180,11 +185,19 @@ def main(argv):
|
||||
# Save the feature extractor to a framework-agnostic SavedModel for reuse.
|
||||
# Note that the feature_extractor object has not been compiled or fitted,
|
||||
# so it does not contain an optimizer and related state.
|
||||
exportable = wrap_keras_model_for_export(feature_extractor,
|
||||
(None,) + mnist_util.INPUT_SHAPE,
|
||||
set_feature_extractor_hparams,
|
||||
default_hparams)
|
||||
tf.saved_model.save(exportable, FLAGS.export_dir)
|
||||
if FLAGS.use_keras_save_api:
|
||||
# Use Keras' built-in way of creating reusable SavedModels.
|
||||
# This has no support for adjustable hparams at this time (July 2019).
|
||||
# (We could also call tf.saved_model.save(feature_extractor, ...),
|
||||
# point is we're passing a Keras model, not a plain Checkpoint.)
|
||||
tf.keras.models.save_model(feature_extractor, FLAGS.export_dir)
|
||||
else:
|
||||
# Assemble a reusable SavedModel manually, with adjustable hparams.
|
||||
exportable = wrap_keras_model_for_export(feature_extractor,
|
||||
(None,) + mnist_util.INPUT_SHAPE,
|
||||
set_feature_extractor_hparams,
|
||||
default_hparams)
|
||||
tf.saved_model.save(exportable, FLAGS.export_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -74,16 +74,19 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
|
||||
combinations=(
|
||||
combinations.combine(
|
||||
# Test all combinations with tf.saved_model.save().
|
||||
use_keras_save_api=False,
|
||||
# Test all combinations using tf.keras.models.save_model()
|
||||
# for both the reusable and the final full model.
|
||||
use_keras_save_api=True,
|
||||
named_strategy=list(ds_utils.named_strategies.values()),
|
||||
retrain_flag_value=["true", "false"],
|
||||
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
|
||||
) + combinations.combine(
|
||||
# Test few critcial combinations with tf.keras.models.save_model()
|
||||
# which is merely a thin wrapper (as of June 2019).
|
||||
use_keras_save_api=True,
|
||||
# Test few critcial combinations with raw tf.saved_model.save(),
|
||||
# including export of a reusable SavedModel that gets assembled
|
||||
# manually, including support for adjustable hparams.
|
||||
use_keras_save_api=False,
|
||||
named_strategy=None,
|
||||
retrain_flag_value="true",
|
||||
retrain_flag_value=["true", "false"],
|
||||
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
|
||||
)),
|
||||
test_combinations=[combinations.NamedGPUCombination()])
|
||||
@ -102,14 +105,14 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
|
||||
self.assertCommandSucceeded(
|
||||
"export_mnist_cnn",
|
||||
fast_test_mode=fast_test_mode,
|
||||
export_dir=feature_extrator_dir)
|
||||
export_dir=feature_extrator_dir,
|
||||
use_keras_save_api=use_keras_save_api)
|
||||
|
||||
use_kwargs = dict(fast_test_mode=fast_test_mode,
|
||||
input_saved_model_dir=feature_extrator_dir,
|
||||
retrain=retrain_flag_value,
|
||||
output_saved_model_dir=full_model_dir,
|
||||
use_keras_save_api=use_keras_save_api)
|
||||
if full_model_dir is not None:
|
||||
use_kwargs["output_saved_model_dir"] = full_model_dir
|
||||
if named_strategy:
|
||||
use_kwargs["strategy"] = str(named_strategy)
|
||||
if regularization_loss_multiplier is not None:
|
||||
@ -117,11 +120,10 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
|
||||
"regularization_loss_multiplier"] = regularization_loss_multiplier
|
||||
self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs)
|
||||
|
||||
if full_model_dir is not None:
|
||||
self.assertCommandSucceeded(
|
||||
"deploy_mnist_cnn",
|
||||
fast_test_mode=fast_test_mode,
|
||||
saved_model_dir=full_model_dir)
|
||||
self.assertCommandSucceeded(
|
||||
"deploy_mnist_cnn",
|
||||
fast_test_mode=fast_test_mode,
|
||||
saved_model_dir=full_model_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -47,7 +47,8 @@ flags.DEFINE_bool(
|
||||
'If set, the imported SavedModel is trained further.')
|
||||
flags.DEFINE_float(
|
||||
'dropout_rate', None,
|
||||
'If set, dropout rate passed to the SavedModel.')
|
||||
'If set, dropout rate passed to the SavedModel. '
|
||||
'Requires a SavedModel with support for adjustable hyperparameters.')
|
||||
flags.DEFINE_float(
|
||||
'regularization_loss_multiplier', None,
|
||||
'If set, multiplier for the regularization losses in the SavedModel.')
|
||||
|
Loading…
Reference in New Issue
Block a user