Automated rollback of commit 19ac5f4f6c
PiperOrigin-RevId: 295802414 Change-Id: I344ec4bb8a0a2cb9921f2f36fa86da9c7f2b55e3
This commit is contained in:
parent
49a83c96b0
commit
36fe0e7aad
@ -1032,7 +1032,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
else:
|
||||
# TODO(josh11b): Once we add support for model parallelism, get the
|
||||
# copy from the corresponding replica instead of the primary.
|
||||
index.append(array_ops.identity(all_reduced._primary)) # pylint: disable=protected-access
|
||||
index.append(array_ops.identity(all_reduced.primary))
|
||||
return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
|
||||
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
|
||||
|
@ -1334,7 +1334,7 @@ class FunctionTest(test.TestCase):
|
||||
def forward(x, w, b):
|
||||
return x * w + b
|
||||
x = constant_op.constant([1.0], name="x_useless")
|
||||
concrete_forward = forward.get_concrete_function(x, w._primary, b._primary)
|
||||
concrete_forward = forward.get_concrete_function(x, w.primary, b.primary)
|
||||
|
||||
with ms.scope():
|
||||
def replica_fn():
|
||||
@ -1350,8 +1350,8 @@ class FunctionTest(test.TestCase):
|
||||
g1, g2 = step_fn()
|
||||
run_metadata = context.export_run_metadata()
|
||||
context.disable_run_metadata()
|
||||
self.assertEqual(self.evaluate(g1._primary), 1.0)
|
||||
self.assertEqual(self.evaluate(g2._primary), 1.0)
|
||||
self.assertEqual(self.evaluate(g1.primary), 1.0)
|
||||
self.assertEqual(self.evaluate(g2.primary), 1.0)
|
||||
|
||||
# Verify that this node runs on both devices.
|
||||
node_name = "gradients_mul_grad_mul_1_x"
|
||||
|
@ -487,7 +487,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
def _select_fn(x): # pylint: disable=g-missing-docstring
|
||||
if isinstance(x, values.Mirrored):
|
||||
if len(x.devices) == 1:
|
||||
return x._primary # pylint: disable=protected-access
|
||||
return x.primary
|
||||
else:
|
||||
raise ValueError(
|
||||
"You cannot update variable with a Mirrored object with multiple "
|
||||
|
@ -75,7 +75,7 @@ class DistributedValues(object):
|
||||
"replica accesses.")
|
||||
|
||||
def _get_closest(self):
|
||||
"""Returns value in same replica or device if possible, else the _primary."""
|
||||
"""Returns value in same replica or device if possible, else the primary."""
|
||||
replica_id = _get_current_replica_id_as_int()
|
||||
if replica_id is None:
|
||||
# Try to find a value on the current device.
|
||||
@ -83,12 +83,12 @@ class DistributedValues(object):
|
||||
for value in self._values:
|
||||
if device_util.canonicalize(value.device) == current_device:
|
||||
return value
|
||||
return self._primary
|
||||
return self.primary
|
||||
else:
|
||||
return self._values[replica_id]
|
||||
|
||||
@property
|
||||
def _primary(self):
|
||||
def primary(self):
|
||||
"""Returns a representative component."""
|
||||
return self._values[0]
|
||||
|
||||
@ -368,7 +368,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
def __init__(self, strategy, values):
|
||||
self._distribute_strategy = strategy
|
||||
super(DistributedVariable, self).__init__(values)
|
||||
self._common_name = self._primary.name.split(":")[0]
|
||||
self._common_name = self.primary.name.split(":")[0]
|
||||
# Use a weakref to make it easy to map from the contained values
|
||||
# to the container without introducing a reference cycle.
|
||||
for v in values:
|
||||
@ -395,7 +395,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
The op that evaluates to True or False depending on if all the
|
||||
component variables are initialized.
|
||||
"""
|
||||
result = self._primary.is_initialized()
|
||||
result = self.primary.is_initialized()
|
||||
# We iterate through the list of values except the last one to allow us to
|
||||
# name the final `logical_and` op the same name that is passed by the user
|
||||
# to the `is_initialized` op. For distributed variables, the
|
||||
@ -426,11 +426,11 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
|
||||
@property
|
||||
def constraint(self):
|
||||
return self._primary.constraint
|
||||
return self.primary.constraint
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
return self._primary.graph
|
||||
return self.primary.graph
|
||||
|
||||
@property
|
||||
def _shared_name(self):
|
||||
@ -438,28 +438,28 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
|
||||
@property
|
||||
def _unique_id(self):
|
||||
return self._primary._unique_id # pylint: disable=protected-access
|
||||
return self.primary._unique_id # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _graph_key(self):
|
||||
"""Lets Optimizers know which graph this variable is from."""
|
||||
return self._primary._graph_key # pylint: disable=protected-access
|
||||
return self.primary._graph_key # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._primary.name
|
||||
return self.primary.name
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._primary.dtype
|
||||
return self.primary.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._primary.shape
|
||||
return self.primary.shape
|
||||
|
||||
@property
|
||||
def synchronization(self):
|
||||
return self._primary.synchronization
|
||||
return self.primary.synchronization
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
@ -475,10 +475,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
|
||||
@property
|
||||
def _save_slice_info(self):
|
||||
return self._primary._save_slice_info # pylint: disable=protected-access
|
||||
return self.primary._save_slice_info # pylint: disable=protected-access
|
||||
|
||||
def _get_save_slice_info(self):
|
||||
return self._primary._get_save_slice_info() # pylint: disable=protected-access
|
||||
return self.primary._get_save_slice_info() # pylint: disable=protected-access
|
||||
|
||||
def _set_save_slice_info(self, save_slice_info):
|
||||
for v in self._values:
|
||||
@ -490,17 +490,17 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
|
||||
@property
|
||||
def trainable(self):
|
||||
return self._primary.trainable
|
||||
return self.primary.trainable
|
||||
|
||||
@property
|
||||
def distribute_strategy(self):
|
||||
return self._distribute_strategy
|
||||
|
||||
def get_shape(self):
|
||||
return self._primary.get_shape()
|
||||
return self.primary.get_shape()
|
||||
|
||||
def to_proto(self, export_scope=None):
|
||||
return self._primary.to_proto(export_scope=export_scope)
|
||||
return self.primary.to_proto(export_scope=export_scope)
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
@ -508,13 +508,13 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
# to work (even if the current device isn't in self.devices), but
|
||||
# other uses of var.op in a cross-replica context to fail.
|
||||
if distribution_strategy_context.in_cross_replica_context():
|
||||
return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
|
||||
self._primary.op.traceback, self._primary.op.type)
|
||||
return DistributedVarOp(self.primary.op.name, self.primary.op.graph,
|
||||
self.primary.op.traceback, self.primary.op.type)
|
||||
return self._get().op
|
||||
|
||||
@property
|
||||
def _in_graph_mode(self):
|
||||
return self._primary._in_graph_mode # pylint: disable=protected-access
|
||||
return self.primary._in_graph_mode # pylint: disable=protected-access
|
||||
|
||||
def read_value(self):
|
||||
with _enter_or_assert_strategy(self._distribute_strategy):
|
||||
@ -567,7 +567,7 @@ class TPUVariableMixin(object):
|
||||
# Handle ID is needed for `get_replicated_var_handle` to cache the variables
|
||||
# correctly since in eager mode different variables can have the same name.
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
self._handle_id = self._common_name + "_" + str(id(self._primary))
|
||||
self._handle_id = self._common_name + "_" + str(id(self.primary))
|
||||
else:
|
||||
self._handle_id = self._common_name
|
||||
|
||||
@ -592,7 +592,7 @@ class TPUVariableMixin(object):
|
||||
if _enclosing_tpu_context() is None:
|
||||
return super(TPUVariableMixin, self)._get_closest()
|
||||
else:
|
||||
return self._primary
|
||||
return self.primary
|
||||
|
||||
def numpy(self):
|
||||
if context.executing_eagerly():
|
||||
@ -644,8 +644,8 @@ class TPUVariableMixin(object):
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
|
||||
self._primary.op.traceback, self._primary.op.type)
|
||||
return DistributedVarOp(self.primary.op.name, self.primary.op.graph,
|
||||
self.primary.op.traceback, self.primary.op.type)
|
||||
|
||||
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
|
||||
"""Converts a variable to a tensor."""
|
||||
@ -900,7 +900,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
"""
|
||||
|
||||
def _saveable_factory(name=self._common_name):
|
||||
return _MirroredSaveable(self, self._primary, name)
|
||||
return _MirroredSaveable(self, self.primary, name)
|
||||
|
||||
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
|
||||
|
||||
@ -1003,8 +1003,7 @@ class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
|
||||
slice_spec="",
|
||||
name=name,
|
||||
dtype=sync_on_read_variable.dtype,
|
||||
device=sync_on_read_variable._primary.device) # pylint: disable=protected-access
|
||||
|
||||
device=sync_on_read_variable.primary.device)
|
||||
super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name)
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
@ -1104,7 +1103,7 @@ class SyncOnReadVariable(DistributedVariable):
|
||||
|
||||
def _get_cross_replica(self):
|
||||
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||
return self._primary
|
||||
return self.primary
|
||||
|
||||
with _enter_or_assert_strategy(self._distribute_strategy):
|
||||
return self._distribute_strategy.reduce(
|
||||
|
@ -274,7 +274,7 @@ class _SaveableView(object):
|
||||
self.captured_tensor_node_ids[obj.resource_handle] = node_id
|
||||
elif (ds_values.is_distributed_variable(obj) or
|
||||
resource_variable_ops.is_resource_variable(obj)):
|
||||
obj_to_copy = obj._primary if ds_values.is_distributed_variable( # pylint: disable=protected-access
|
||||
obj_to_copy = obj.primary if ds_values.is_distributed_variable(
|
||||
obj) else obj
|
||||
new_variable = resource_variable_ops.copy_to_graph_uninitialized(
|
||||
obj_to_copy)
|
||||
|
Loading…
Reference in New Issue
Block a user