Fix breakage related to HParam in tpu_estimator.py
PiperOrigin-RevId: 236415795
This commit is contained in:
parent
70da1fe25d
commit
4de282cf8a
@ -2768,10 +2768,10 @@ class TPUEstimator(estimator_lib.Estimator):
|
|||||||
|
|
||||||
if is_export_mode:
|
if is_export_mode:
|
||||||
if mode == _REWRITE_FOR_INFERENCE_MODE:
|
if mode == _REWRITE_FOR_INFERENCE_MODE:
|
||||||
params['use_tpu'] = True
|
_add_item_to_params(params, _USE_TPU_KEY, True)
|
||||||
mode = model_fn_lib.ModeKeys.PREDICT
|
mode = model_fn_lib.ModeKeys.PREDICT
|
||||||
else:
|
else:
|
||||||
params['use_tpu'] = False
|
_add_item_to_params(params, _USE_TPU_KEY, False)
|
||||||
|
|
||||||
with self._ctx.with_mode(mode) as ctx:
|
with self._ctx.with_mode(mode) as ctx:
|
||||||
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
|
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
|
||||||
|
Loading…
Reference in New Issue
Block a user