Explicitly disable re-initialize_tpu for MeshTF.
PiperOrigin-RevId: 245790629
This commit is contained in:
parent
d8c82221f0
commit
0251c63f04
@ -484,7 +484,8 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
|
||||
# is causing the variable corruption since the previous allocated memory
|
||||
# might be overwritten for other purpose.
|
||||
if (ctx.model_parallelism_enabled and
|
||||
ctx.is_input_broadcast_with_iterators()):
|
||||
(ctx.config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.BROADCAST)):
|
||||
self._should_initialize_tpu = False
|
||||
else:
|
||||
self._should_initialize_tpu = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user