From 710d3113bf63558aa8a0faccab9cdb562052692e Mon Sep 17 00:00:00 2001 From: Rick Chao Date: Mon, 22 Jul 2019 13:23:08 -0700 Subject: [PATCH] Unwrap `initial_value` if it is a `CheckpointInitialValue` in collective_all_reduce_strategy's `initial_value_fn`. This fixes a bug where running keras_mnist_multi_worker with eager execution causes a segmentation fault. PiperOrigin-RevId: 259393313 --- .../python/distribute/collective_all_reduce_strategy.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index e35f95a0331..c43d28b0226 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -40,6 +40,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util.tf_export import tf_export @@ -335,6 +336,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): if self._num_workers > 1: if self._is_chief: + # Unwrap `initial_value` if it is a `CheckpointInitialValue`. + # TODO(b/138130844): Revert the following check once + # `CheckpointInitialValue` class is removed. + if isinstance(initial_value, trackable.CheckpointInitialValue): + initial_value = initial_value.wrapped_value bcast_send = collective_ops.broadcast_send( initial_value, initial_value.shape, initial_value.dtype, group_size, group_key, collective_instance_key)