Support export eval graph on CPU in TPUEstimator.export_savedmodel().
PiperOrigin-RevId: 209682829
This commit is contained in:
parent
30012aeb8c
commit
9941300301
@ -390,12 +390,6 @@ class _InternalTPUContext(object):
|
||||
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
|
||||
return True
|
||||
|
||||
if mode != model_fn_lib.ModeKeys.PREDICT:
|
||||
return False
|
||||
|
||||
# There are actually 2 use cases when running with mode.PREDICT: prediction
|
||||
# and saving the model. We run actual predictions on the TPU, but
|
||||
# model export is run on the CPU.
|
||||
if is_export_mode:
|
||||
return True
|
||||
|
||||
|
@ -2143,9 +2143,10 @@ class TPUEstimator(estimator_lib.Estimator):
|
||||
mode=model_fn_lib.ModeKeys.PREDICT,
|
||||
export_tags=None,
|
||||
check_variables=True):
|
||||
if mode != model_fn_lib.ModeKeys.PREDICT:
|
||||
if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT:
|
||||
raise NotImplementedError(
|
||||
'TPUEstimator only handles mode PREDICT for export_savedmodel(); '
|
||||
'TPUEstimator only handles mode PREDICT for exporting '
|
||||
'when `export_to_tpu` is `True`; '
|
||||
'got {}.'.format(mode))
|
||||
|
||||
(super(TPUEstimator, self).
|
||||
@ -2443,16 +2444,12 @@ class TPUEstimator(estimator_lib.Estimator):
|
||||
with self._ctx.with_mode(mode) as ctx:
|
||||
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
|
||||
|
||||
if mode != model_fn_lib.ModeKeys.PREDICT:
|
||||
# `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
|
||||
# but not in `export_savedmodel()`.
|
||||
if self._is_input_fn_invoked:
|
||||
is_export_mode = False
|
||||
else:
|
||||
# For export_savedmodel, input_fn is never passed to Estimator. So, by
|
||||
# checking the self._is_input_fn_invoked bit, we can know, given the
|
||||
# mode == PREDICT, it is the .predict API, not export_savedmodel API.
|
||||
if self._is_input_fn_invoked:
|
||||
is_export_mode = False
|
||||
else:
|
||||
is_export_mode = True
|
||||
is_export_mode = True
|
||||
|
||||
# Clear the bit.
|
||||
self._is_input_fn_invoked = None
|
||||
|
Loading…
Reference in New Issue
Block a user