Add wrapper to help export model trained with estimator as SavedModel for TPU.
PiperOrigin-RevId: 204568222
This commit is contained in:
parent
274cce5990
commit
88b656acd4
@ -42,10 +42,10 @@
|
||||
|
||||
@@TPUEstimator
|
||||
@@TPUEstimatorSpec
|
||||
@@export_estimator_savedmodel
|
||||
@@RunConfig
|
||||
@@InputPipelineConfig
|
||||
@@TPUConfig
|
||||
|
||||
@@bfloat16_scope
|
||||
"""
|
||||
|
||||
|
@ -3320,3 +3320,47 @@ def _add_item_to_params(params, key, value):
|
||||
else:
|
||||
# Now params is Python dict.
|
||||
params[key] = value
|
||||
|
||||
|
||||
def export_estimator_savedmodel(estimator,
|
||||
export_dir_base,
|
||||
serving_input_receiver_fn,
|
||||
assets_extra=None,
|
||||
as_text=False,
|
||||
checkpoint_path=None,
|
||||
strip_default_attrs=False):
|
||||
"""Export `Estimator` trained model for TPU inference.
|
||||
|
||||
Args:
|
||||
estimator: `Estimator` with which model has been trained.
|
||||
export_dir_base: A string containing a directory in which to create
|
||||
timestamped subdirectories containing exported SavedModels.
|
||||
serving_input_receiver_fn: A function that takes no argument and
|
||||
returns a `ServingInputReceiver` or `TensorServingInputReceiver`.
|
||||
assets_extra: A dict specifying how to populate the assets.extra directory
|
||||
within the exported SavedModel, or `None` if no extra assets are needed.
|
||||
as_text: whether to write the SavedModel proto in text format.
|
||||
checkpoint_path: The checkpoint path to export. If `None` (the default),
|
||||
the most recent checkpoint found within the model directory is chosen.
|
||||
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
||||
removed from the NodeDefs.
|
||||
|
||||
Returns:
|
||||
The string path to the exported directory.
|
||||
"""
|
||||
# `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use
|
||||
# `estimator.config`.
|
||||
config = tpu_config.RunConfig(model_dir=estimator.model_dir)
|
||||
est = TPUEstimator(
|
||||
estimator._model_fn, # pylint: disable=protected-access
|
||||
config=config,
|
||||
params=estimator.params,
|
||||
use_tpu=True,
|
||||
train_batch_size=2048, # Does not matter.
|
||||
eval_batch_size=2048, # Does not matter.
|
||||
)
|
||||
return est.export_savedmodel(export_dir_base, serving_input_receiver_fn,
|
||||
assets_extra,
|
||||
as_text,
|
||||
checkpoint_path,
|
||||
strip_default_attrs)
|
||||
|
Loading…
Reference in New Issue
Block a user