Fix breakage related to HParam in tpu_estimator.py

PiperOrigin-RevId: 236415795
This commit is contained in:
A. Unique TensorFlower 2019-03-01 18:56:04 -08:00 committed by TensorFlower Gardener
parent 70da1fe25d
commit 4de282cf8a

View File

@ -2768,10 +2768,10 @@ class TPUEstimator(estimator_lib.Estimator):
if is_export_mode: if is_export_mode:
if mode == _REWRITE_FOR_INFERENCE_MODE: if mode == _REWRITE_FOR_INFERENCE_MODE:
params['use_tpu'] = True _add_item_to_params(params, _USE_TPU_KEY, True)
mode = model_fn_lib.ModeKeys.PREDICT mode = model_fn_lib.ModeKeys.PREDICT
else: else:
params['use_tpu'] = False _add_item_to_params(params, _USE_TPU_KEY, False)
with self._ctx.with_mode(mode) as ctx: with self._ctx.with_mode(mode) as ctx:
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)