Support export eval graph on CPU in TPUEstimator.export_savedmodel().
PiperOrigin-RevId: 209682829
This commit is contained in:
parent
30012aeb8c
commit
9941300301
@ -390,12 +390,6 @@ class _InternalTPUContext(object):
|
|||||||
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
|
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if mode != model_fn_lib.ModeKeys.PREDICT:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# There are actually 2 use cases when running with mode.PREDICT: prediction
|
|
||||||
# and saving the model. We run actual predictions on the TPU, but
|
|
||||||
# model export is run on the CPU.
|
|
||||||
if is_export_mode:
|
if is_export_mode:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -2143,9 +2143,10 @@ class TPUEstimator(estimator_lib.Estimator):
|
|||||||
mode=model_fn_lib.ModeKeys.PREDICT,
|
mode=model_fn_lib.ModeKeys.PREDICT,
|
||||||
export_tags=None,
|
export_tags=None,
|
||||||
check_variables=True):
|
check_variables=True):
|
||||||
if mode != model_fn_lib.ModeKeys.PREDICT:
|
if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'TPUEstimator only handles mode PREDICT for export_savedmodel(); '
|
'TPUEstimator only handles mode PREDICT for exporting '
|
||||||
|
'when `export_to_tpu` is `True`; '
|
||||||
'got {}.'.format(mode))
|
'got {}.'.format(mode))
|
||||||
|
|
||||||
(super(TPUEstimator, self).
|
(super(TPUEstimator, self).
|
||||||
@ -2443,16 +2444,12 @@ class TPUEstimator(estimator_lib.Estimator):
|
|||||||
with self._ctx.with_mode(mode) as ctx:
|
with self._ctx.with_mode(mode) as ctx:
|
||||||
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
|
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
|
||||||
|
|
||||||
if mode != model_fn_lib.ModeKeys.PREDICT:
|
# `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
|
||||||
|
# but not in `export_savedmodel()`.
|
||||||
|
if self._is_input_fn_invoked:
|
||||||
is_export_mode = False
|
is_export_mode = False
|
||||||
else:
|
else:
|
||||||
# For export_savedmodel, input_fn is never passed to Estimator. So, by
|
is_export_mode = True
|
||||||
# checking the self._is_input_fn_invoked bit, we can know, given the
|
|
||||||
# mode == PREDICT, it is the .predict API, not export_savedmodel API.
|
|
||||||
if self._is_input_fn_invoked:
|
|
||||||
is_export_mode = False
|
|
||||||
else:
|
|
||||||
is_export_mode = True
|
|
||||||
|
|
||||||
# Clear the bit.
|
# Clear the bit.
|
||||||
self._is_input_fn_invoked = None
|
self._is_input_fn_invoked = None
|
||||||
|
Loading…
Reference in New Issue
Block a user