Fix breakage related to HParam in tpu_estimator.py
PiperOrigin-RevId: 236415795
This commit is contained in:
parent
70da1fe25d
commit
4de282cf8a
@ -2768,10 +2768,10 @@ class TPUEstimator(estimator_lib.Estimator):
|
||||
|
||||
if is_export_mode:
|
||||
if mode == _REWRITE_FOR_INFERENCE_MODE:
|
||||
params['use_tpu'] = True
|
||||
_add_item_to_params(params, _USE_TPU_KEY, True)
|
||||
mode = model_fn_lib.ModeKeys.PREDICT
|
||||
else:
|
||||
params['use_tpu'] = False
|
||||
_add_item_to_params(params, _USE_TPU_KEY, False)
|
||||
|
||||
with self._ctx.with_mode(mode) as ctx:
|
||||
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
|
||||
|
Loading…
Reference in New Issue
Block a user