Add wrapper to help export model trained with estimator as SavedModel for TPU.

PiperOrigin-RevId: 204568222
This commit is contained in:
A. Unique TensorFlower 2018-07-13 21:31:49 -07:00 committed by TensorFlower Gardener
parent 274cce5990
commit 88b656acd4
2 changed files with 45 additions and 1 deletions

View File

@ -42,10 +42,10 @@
@@TPUEstimator
@@TPUEstimatorSpec
@@export_estimator_savedmodel
@@RunConfig
@@InputPipelineConfig
@@TPUConfig
@@bfloat16_scope
"""

View File

@ -3320,3 +3320,47 @@ def _add_item_to_params(params, key, value):
else:
# Now params is Python dict.
params[key] = value
def export_estimator_savedmodel(estimator,
export_dir_base,
serving_input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None,
strip_default_attrs=False):
"""Export `Estimator` trained model for TPU inference.
Args:
estimator: `Estimator` with which model has been trained.
export_dir_base: A string containing a directory in which to create
timestamped subdirectories containing exported SavedModels.
serving_input_receiver_fn: A function that takes no argument and
returns a `ServingInputReceiver` or `TensorServingInputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
within the exported SavedModel, or `None` if no extra assets are needed.
as_text: whether to write the SavedModel proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs.
Returns:
The string path to the exported directory.
"""
# `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use
# `estimator.config`.
config = tpu_config.RunConfig(model_dir=estimator.model_dir)
est = TPUEstimator(
estimator._model_fn, # pylint: disable=protected-access
config=config,
params=estimator.params,
use_tpu=True,
train_batch_size=2048, # Does not matter.
eval_batch_size=2048, # Does not matter.
)
return est.export_savedmodel(export_dir_base, serving_input_receiver_fn,
assets_extra,
as_text,
checkpoint_path,
strip_default_attrs)