Support export eval graph on CPU in TPUEstimator.export_savedmodel().

PiperOrigin-RevId: 209682829
This commit is contained in:
A. Unique TensorFlower 2018-08-21 16:43:07 -07:00 committed by TensorFlower Gardener
parent 30012aeb8c
commit 9941300301
2 changed files with 7 additions and 16 deletions

View File

@@ -390,12 +390,6 @@ class _InternalTPUContext(object):
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
return True
if mode != model_fn_lib.ModeKeys.PREDICT:
return False
# There are actually 2 use cases when running with mode.PREDICT: prediction
# and saving the model. We run actual predictions on the TPU, but
# model export is run on the CPU.
if is_export_mode:
return True

View File

@@ -2143,9 +2143,10 @@ class TPUEstimator(estimator_lib.Estimator):
mode=model_fn_lib.ModeKeys.PREDICT,
export_tags=None,
check_variables=True):
if mode != model_fn_lib.ModeKeys.PREDICT:
if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT:
raise NotImplementedError(
'TPUEstimator only handles mode PREDICT for export_savedmodel(); '
'TPUEstimator only handles mode PREDICT for exporting '
'when `export_to_tpu` is `True`; '
'got {}.'.format(mode))
(super(TPUEstimator, self).
@@ -2443,16 +2444,12 @@ class TPUEstimator(estimator_lib.Estimator):
with self._ctx.with_mode(mode) as ctx:
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
if mode != model_fn_lib.ModeKeys.PREDICT:
# `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
# but not in `export_savedmodel()`.
if self._is_input_fn_invoked:
is_export_mode = False
else:
# For export_savedmodel, input_fn is never passed to Estimator. So, by
# checking the self._is_input_fn_invoked bit, we can know, given the
# mode == PREDICT, it is the .predict API, not export_savedmodel API.
if self._is_input_fn_invoked:
is_export_mode = False
else:
is_export_mode = True
is_export_mode = True
# Clear the bit.
self._is_input_fn_invoked = None