Explicitly disable re-initialize_tpu for MeshTF.

PiperOrigin-RevId: 245790629
This commit is contained in:
Youlong Cheng 2019-04-29 11:43:18 -07:00 committed by TensorFlower Gardener
parent d8c82221f0
commit 0251c63f04

View File

@ -484,7 +484,8 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
# is causing the variable corruption since the previous allocated memory
# might be overwritten for other purpose.
if (ctx.model_parallelism_enabled and
ctx.is_input_broadcast_with_iterators()):
(ctx.config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.BROADCAST)):
self._should_initialize_tpu = False
else:
self._should_initialize_tpu = True