From 994130030131a7b9f8f45bd4e37d1a3d36144d51 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 21 Aug 2018 16:43:07 -0700 Subject: [PATCH] Support export eval graph on CPU in TPUEstimator.export_savedmodel(). PiperOrigin-RevId: 209682829 --- .../contrib/tpu/python/tpu/tpu_context.py | 6 ------ .../contrib/tpu/python/tpu/tpu_estimator.py | 17 +++++++---------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 806ae1c4c99..19359cb6122 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -390,12 +390,6 @@ class _InternalTPUContext(object): logging.info('_is_running_on_cpu: eval_on_tpu disabled') return True - if mode != model_fn_lib.ModeKeys.PREDICT: - return False - - # There are actually 2 use cases when running with mode.PREDICT: prediction - # and saving the model. We run actual predictions on the TPU, but - # model export is run on the CPU. if is_export_mode: return True diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index fed07f00e70..2e4050bd997 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -2143,9 +2143,10 @@ class TPUEstimator(estimator_lib.Estimator): mode=model_fn_lib.ModeKeys.PREDICT, export_tags=None, check_variables=True): - if mode != model_fn_lib.ModeKeys.PREDICT: + if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: raise NotImplementedError( - 'TPUEstimator only handles mode PREDICT for export_savedmodel(); ' + 'TPUEstimator only handles mode PREDICT for exporting ' + 'when `export_to_tpu` is `True`; ' 'got {}.'.format(mode)) (super(TPUEstimator, self). @@ -2443,16 +2444,12 @@ class TPUEstimator(estimator_lib.Estimator): with self._ctx.with_mode(mode) as ctx: model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) - if mode != model_fn_lib.ModeKeys.PREDICT: + # `input_fn` is called in `train()`, `evaluate()`, and `predict()`, + # but not in `export_savedmodel()`. + if self._is_input_fn_invoked: is_export_mode = False else: - # For export_savedmodel, input_fn is never passed to Estimator. So, by - # checking the self._is_input_fn_invoked bit, we can know, given the - # mode == PREDICT, it is the .predict API, not export_savedmodel API. - if self._is_input_fn_invoked: - is_export_mode = False - else: - is_export_mode = True + is_export_mode = True # Clear the bit. self._is_input_fn_invoked = None