Unwrap `initial_value` in collective_all_reduce_strategy's `initial_value_fn` if it is a `CheckpointInitialValue`. This fixes a bug where running keras_mnist_multi_worker in eager mode caused a segmentation fault.

PiperOrigin-RevId: 259393313
This commit is contained in:
Rick Chao 2019-07-22 13:23:08 -07:00 committed by TensorFlower Gardener
parent 8614aaf955
commit 710d3113bf

View File

@@ -40,6 +40,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import tf_export
@@ -335,6 +336,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
if self._num_workers > 1:
if self._is_chief:
# Unwrap `initial_value` if it is a `CheckpointInitialValue`.
# TODO(b/138130844): Revert the following check once
# `CheckpointInitialValue` class is removed.
if isinstance(initial_value, trackable.CheckpointInitialValue):
initial_value = initial_value.wrapped_value
bcast_send = collective_ops.broadcast_send(
initial_value, initial_value.shape, initial_value.dtype,
group_size, group_key, collective_instance_key)