diff --git a/tensorflow/python/tpu/tpu_estimator.py b/tensorflow/python/tpu/tpu_estimator.py index 2bc5eb2d63c..0c231e29ae9 100644 --- a/tensorflow/python/tpu/tpu_estimator.py +++ b/tensorflow/python/tpu/tpu_estimator.py @@ -484,7 +484,8 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): # is causing the variable corruption since the previous allocated memory # might be overwritten for other purpose. if (ctx.model_parallelism_enabled and - ctx.is_input_broadcast_with_iterators()): + (ctx.config.tpu_config.per_host_input_for_training is + tpu_config.InputPipelineConfig.BROADCAST)): self._should_initialize_tpu = False else: self._should_initialize_tpu = True