Unwrap `initial_value` if it is a `CheckpointInitialValue` in `collective_all_reduce_strategy`'s `initial_value_fn`.
This fixes a bug where running `keras_mnist_multi_worker` with eager execution causes a segmentation fault.
PiperOrigin-RevId: 259393313
This commit is contained in:
parent
8614aaf955
commit
710d3113bf
@ -40,6 +40,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import collective_ops
|
from tensorflow.python.ops import collective_ops
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@ -335,6 +336,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
|||||||
|
|
||||||
if self._num_workers > 1:
|
if self._num_workers > 1:
|
||||||
if self._is_chief:
|
if self._is_chief:
|
||||||
|
# Unwrap `initial_value` if it is a `CheckpointInitialValue`.
|
||||||
|
# TODO(b/138130844): Revert the following check once
|
||||||
|
# `CheckpointInitialValue` class is removed.
|
||||||
|
if isinstance(initial_value, trackable.CheckpointInitialValue):
|
||||||
|
initial_value = initial_value.wrapped_value
|
||||||
bcast_send = collective_ops.broadcast_send(
|
bcast_send = collective_ops.broadcast_send(
|
||||||
initial_value, initial_value.shape, initial_value.dtype,
|
initial_value, initial_value.shape, initial_value.dtype,
|
||||||
group_size, group_key, collective_instance_key)
|
group_size, group_key, collective_instance_key)
|
||||||
|
Loading…
Reference in New Issue
Block a user