Automated rollback of commit 19ac5f4f6c
PiperOrigin-RevId: 295802414
Change-Id: I344ec4bb8a0a2cb9921f2f36fa86da9c7f2b55e3
parent 49a83c96b0
commit 36fe0e7aad
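Every hunk below makes the same change: the underscore-prefixed `_primary` accessor on distributed values is renamed back to `primary`, and the now-unneeded `# pylint: disable=protected-access` suppressions at the call sites are dropped. A minimal, self-contained sketch of the accessor pattern being restored (the class and field names here are illustrative stand-ins, not TensorFlow's actual implementation):

# Hedged sketch of the pattern restored by this rollback: a container of
# per-replica values exposes its first component through a `primary` property.
# All names below are illustrative, not TensorFlow APIs.
class DistributedValuesSketch(object):

  def __init__(self, values):
    self._values = tuple(values)  # one component per replica/device

  @property
  def primary(self):
    """Returns a representative component (the first one)."""
    return self._values[0]


# Callers read `container.primary` instead of `container._primary`, so no
# pylint protected-access suppressions are needed at the call sites.
container = DistributedValuesSketch(["component_0", "component_1"])
assert container.primary == "component_0"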
@@ -1032,7 +1032,7 @@ class CollectiveAllReduce(CrossDeviceOps):
         else:
           # TODO(josh11b): Once we add support for model parallelism, get the
           # copy from the corresponding replica instead of the primary.
-          index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
+          index.append(array_ops.identity(all_reduced.primary))
     return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
 
   def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
@@ -1334,7 +1334,7 @@ class FunctionTest(test.TestCase):
     def forward(x, w, b):
       return x * w + b
     x = constant_op.constant([1.0], name="x_useless")
-    concrete_forward = forward.get_concrete_function(x, w._primary, b._primary)
+    concrete_forward = forward.get_concrete_function(x, w.primary, b.primary)
 
     with ms.scope():
       def replica_fn():
@@ -1350,8 +1350,8 @@ class FunctionTest(test.TestCase):
     g1, g2 = step_fn()
     run_metadata = context.export_run_metadata()
     context.disable_run_metadata()
-    self.assertEqual(self.evaluate(g1._primary), 1.0)
-    self.assertEqual(self.evaluate(g2._primary), 1.0)
+    self.assertEqual(self.evaluate(g1.primary), 1.0)
+    self.assertEqual(self.evaluate(g2.primary), 1.0)
 
     # Verify that this node runs on both devices.
     node_name = "gradients_mul_grad_mul_1_x"
@@ -487,7 +487,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     def _select_fn(x):  # pylint: disable=g-missing-docstring
       if isinstance(x, values.Mirrored):
         if len(x.devices) == 1:
-          return x._primary  # pylint: disable=protected-access
+          return x.primary
         else:
           raise ValueError(
               "You cannot update variable with a Mirrored object with multiple "
@@ -75,7 +75,7 @@ class DistributedValues(object):
         "replica accesses.")
 
   def _get_closest(self):
-    """Returns value in same replica or device if possible, else the _primary."""
+    """Returns value in same replica or device if possible, else the primary."""
     replica_id = _get_current_replica_id_as_int()
     if replica_id is None:
       # Try to find a value on the current device.
@@ -83,12 +83,12 @@ class DistributedValues(object):
       for value in self._values:
         if device_util.canonicalize(value.device) == current_device:
           return value
-      return self._primary
+      return self.primary
     else:
       return self._values[replica_id]
 
   @property
-  def _primary(self):
+  def primary(self):
     """Returns a representative component."""
     return self._values[0]
 
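The `_get_closest` logic in this hunk returns the component on the current device when one exists and otherwise falls back to the representative `primary` component. A rough standalone sketch of that lookup, with a trivial stand-in for `device_util.canonicalize` (everything here is hypothetical, not the actual TensorFlow implementation):

# Hedged sketch of the closest-value lookup shown above. `canonicalize` is a
# deliberately trivial stand-in for device_util.canonicalize.
def canonicalize(device_name):
  return device_name.lower()


def get_closest(components, current_device, replica_id=None):
  """Returns the component on `current_device` if any, else the primary one."""
  if replica_id is None:
    for device_name, value in components:
      if canonicalize(device_name) == canonicalize(current_device):
        return value
    return components[0][1]  # no local match: fall back to the primary
  return components[replica_id][1]


# Example with two hypothetical device placements.
components = [("/GPU:0", "v0"), ("/GPU:1", "v1")]
assert get_closest(components, "/gpu:1") == "v1"
assert get_closest(components, "/CPU:0") == "v0"  # falls back to primary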
@@ -368,7 +368,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
   def __init__(self, strategy, values):
     self._distribute_strategy = strategy
     super(DistributedVariable, self).__init__(values)
-    self._common_name = self._primary.name.split(":")[0]
+    self._common_name = self.primary.name.split(":")[0]
     # Use a weakref to make it easy to map from the contained values
     # to the container without introducing a reference cycle.
     for v in values:
@@ -395,7 +395,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
       The op that evaluates to True or False depending on if all the
       component variables are initialized.
     """
-    result = self._primary.is_initialized()
+    result = self.primary.is_initialized()
     # We iterate through the list of values except the last one to allow us to
     # name the final `logical_and` op the same name that is passed by the user
     # to the `is_initialized` op. For distributed variables, the
@@ -426,11 +426,11 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
 
   @property
   def constraint(self):
-    return self._primary.constraint
+    return self.primary.constraint
 
   @property
   def graph(self):
-    return self._primary.graph
+    return self.primary.graph
 
   @property
   def _shared_name(self):
@@ -438,28 +438,28 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
 
   @property
   def _unique_id(self):
-    return self._primary._unique_id  # pylint: disable=protected-access
+    return self.primary._unique_id  # pylint: disable=protected-access
 
   @property
   def _graph_key(self):
     """Lets Optimizers know which graph this variable is from."""
-    return self._primary._graph_key  # pylint: disable=protected-access
+    return self.primary._graph_key  # pylint: disable=protected-access
 
   @property
   def name(self):
-    return self._primary.name
+    return self.primary.name
 
   @property
   def dtype(self):
-    return self._primary.dtype
+    return self.primary.dtype
 
   @property
   def shape(self):
-    return self._primary.shape
+    return self.primary.shape
 
   @property
   def synchronization(self):
-    return self._primary.synchronization
+    return self.primary.synchronization
 
   @property
   def handle(self):
@@ -475,10 +475,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
 
   @property
   def _save_slice_info(self):
-    return self._primary._save_slice_info  # pylint: disable=protected-access
+    return self.primary._save_slice_info  # pylint: disable=protected-access
 
   def _get_save_slice_info(self):
-    return self._primary._get_save_slice_info()  # pylint: disable=protected-access
+    return self.primary._get_save_slice_info()  # pylint: disable=protected-access
 
   def _set_save_slice_info(self, save_slice_info):
     for v in self._values:
@@ -490,17 +490,17 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
 
   @property
   def trainable(self):
-    return self._primary.trainable
+    return self.primary.trainable
 
   @property
   def distribute_strategy(self):
     return self._distribute_strategy
 
   def get_shape(self):
-    return self._primary.get_shape()
+    return self.primary.get_shape()
 
   def to_proto(self, export_scope=None):
-    return self._primary.to_proto(export_scope=export_scope)
+    return self.primary.to_proto(export_scope=export_scope)
 
   @property
   def op(self):
@@ -508,13 +508,13 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
     # to work (even if the current device isn't in self.devices), but
     # other uses of var.op in a cross-replica context to fail.
     if distribution_strategy_context.in_cross_replica_context():
-      return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
-                              self._primary.op.traceback, self._primary.op.type)
+      return DistributedVarOp(self.primary.op.name, self.primary.op.graph,
+                              self.primary.op.traceback, self.primary.op.type)
     return self._get().op
 
   @property
   def _in_graph_mode(self):
-    return self._primary._in_graph_mode  # pylint: disable=protected-access
+    return self.primary._in_graph_mode  # pylint: disable=protected-access
 
   def read_value(self):
     with _enter_or_assert_strategy(self._distribute_strategy):
@@ -567,7 +567,7 @@ class TPUVariableMixin(object):
     # Handle ID is needed for `get_replicated_var_handle` to cache the variables
    # correctly since in eager mode different variables can have the same name.
     if ops.executing_eagerly_outside_functions():
-      self._handle_id = self._common_name + "_" + str(id(self._primary))
+      self._handle_id = self._common_name + "_" + str(id(self.primary))
     else:
       self._handle_id = self._common_name
 
@@ -592,7 +592,7 @@ class TPUVariableMixin(object):
     if _enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self)._get_closest()
     else:
-      return self._primary
+      return self.primary
 
   def numpy(self):
     if context.executing_eagerly():
@@ -644,8 +644,8 @@ class TPUVariableMixin(object):
 
   @property
   def op(self):
-    return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
-                            self._primary.op.traceback, self._primary.op.type)
+    return DistributedVarOp(self.primary.op.name, self.primary.op.graph,
+                            self.primary.op.traceback, self.primary.op.type)
 
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
@@ -900,7 +900,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
     """
 
     def _saveable_factory(name=self._common_name):
-      return _MirroredSaveable(self, self._primary, name)
+      return _MirroredSaveable(self, self.primary, name)
 
     return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
 
@@ -1003,8 +1003,7 @@ class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
         slice_spec="",
         name=name,
         dtype=sync_on_read_variable.dtype,
-        device=sync_on_read_variable._primary.device)  # pylint: disable=protected-access
+        device=sync_on_read_variable.primary.device)
 
     super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name)
 
   def restore(self, restored_tensors, restored_shapes):
@@ -1104,7 +1103,7 @@ class SyncOnReadVariable(DistributedVariable):
 
   def _get_cross_replica(self):
     if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
-      return self._primary
+      return self.primary
 
     with _enter_or_assert_strategy(self._distribute_strategy):
       return self._distribute_strategy.reduce(
@@ -274,7 +274,7 @@ class _SaveableView(object):
         self.captured_tensor_node_ids[obj.resource_handle] = node_id
       elif (ds_values.is_distributed_variable(obj) or
             resource_variable_ops.is_resource_variable(obj)):
-        obj_to_copy = obj._primary if ds_values.is_distributed_variable(  # pylint: disable=protected-access
+        obj_to_copy = obj.primary if ds_values.is_distributed_variable(
            obj) else obj
         new_variable = resource_variable_ops.copy_to_graph_uninitialized(
             obj_to_copy)
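The `_SaveableView` hunk above copies the `primary` component into the export graph when the tracked object is a distributed variable, and copies the object itself otherwise. A hedged sketch of just that selection step, with `is_distributed` standing in for `ds_values.is_distributed_variable` (the names and duck typing here are illustrative, not TensorFlow's actual checks):

# Hedged sketch of the selection step above.
def is_distributed(obj):
  return hasattr(obj, "primary")


def object_to_copy(obj):
  """Picks the representative component of a distributed value, else the object itself."""
  return obj.primary if is_distributed(obj) else obj


class _FakeDistributed(object):
  primary = "primary_component"

assert object_to_copy(_FakeDistributed()) == "primary_component"
assert object_to_copy("plain_variable") == "plain_variable"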