Unwrap initial_value
if it is a CheckpointInitialValue
in collective_all_reduce_strategy's initial_value_fn
. This fixes a bug where running keras_mnist_multi_worker with eager causes seg fault.
PiperOrigin-RevId: 259393313
This commit is contained in:
parent
8614aaf955
commit
710d3113bf
@ -40,6 +40,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import collective_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -335,6 +336,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
|
||||
if self._num_workers > 1:
|
||||
if self._is_chief:
|
||||
# Unwrap `initial_value` if it is a `CheckpointInitialValue`.
|
||||
# TODO(b/138130844): Revert the following check once
|
||||
# `CheckpointInitialValue` class is removed.
|
||||
if isinstance(initial_value, trackable.CheckpointInitialValue):
|
||||
initial_value = initial_value.wrapped_value
|
||||
bcast_send = collective_ops.broadcast_send(
|
||||
initial_value, initial_value.shape, initial_value.dtype,
|
||||
group_size, group_key, collective_instance_key)
|
||||
|
Loading…
Reference in New Issue
Block a user