Automated rollback of commit 19ac5f4f6c

PiperOrigin-RevId: 295802414
Change-Id: I344ec4bb8a0a2cb9921f2f36fa86da9c7f2b55e3
This commit is contained in:
A. Unique TensorFlower 2020-02-18 13:01:54 -08:00 committed by TensorFlower Gardener
parent 49a83c96b0
commit 36fe0e7aad
5 changed files with 34 additions and 35 deletions

View File

@ -1032,7 +1032,7 @@ class CollectiveAllReduce(CrossDeviceOps):
else: else:
# TODO(josh11b): Once we add support for model parallelism, get the # TODO(josh11b): Once we add support for model parallelism, get the
# copy from the corresponding replica instead of the primary. # copy from the corresponding replica instead of the primary.
index.append(array_ops.identity(all_reduced._primary)) # pylint: disable=protected-access index.append(array_ops.identity(all_reduced.primary))
return value_lib.regroup(index, wrap_class=value_lib.Mirrored) return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
def batch_reduce_implementation(self, reduce_op, value_destination_pairs): def batch_reduce_implementation(self, reduce_op, value_destination_pairs):

View File

@ -1334,7 +1334,7 @@ class FunctionTest(test.TestCase):
def forward(x, w, b): def forward(x, w, b):
return x * w + b return x * w + b
x = constant_op.constant([1.0], name="x_useless") x = constant_op.constant([1.0], name="x_useless")
concrete_forward = forward.get_concrete_function(x, w._primary, b._primary) concrete_forward = forward.get_concrete_function(x, w.primary, b.primary)
with ms.scope(): with ms.scope():
def replica_fn(): def replica_fn():
@ -1350,8 +1350,8 @@ class FunctionTest(test.TestCase):
g1, g2 = step_fn() g1, g2 = step_fn()
run_metadata = context.export_run_metadata() run_metadata = context.export_run_metadata()
context.disable_run_metadata() context.disable_run_metadata()
self.assertEqual(self.evaluate(g1._primary), 1.0) self.assertEqual(self.evaluate(g1.primary), 1.0)
self.assertEqual(self.evaluate(g2._primary), 1.0) self.assertEqual(self.evaluate(g2.primary), 1.0)
# Verify that this node runs on both devices. # Verify that this node runs on both devices.
node_name = "gradients_mul_grad_mul_1_x" node_name = "gradients_mul_grad_mul_1_x"

View File

@ -487,7 +487,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
def _select_fn(x): # pylint: disable=g-missing-docstring def _select_fn(x): # pylint: disable=g-missing-docstring
if isinstance(x, values.Mirrored): if isinstance(x, values.Mirrored):
if len(x.devices) == 1: if len(x.devices) == 1:
return x._primary # pylint: disable=protected-access return x.primary
else: else:
raise ValueError( raise ValueError(
"You cannot update variable with a Mirrored object with multiple " "You cannot update variable with a Mirrored object with multiple "

View File

@ -75,7 +75,7 @@ class DistributedValues(object):
"replica accesses.") "replica accesses.")
def _get_closest(self): def _get_closest(self):
"""Returns value in same replica or device if possible, else the _primary.""" """Returns value in same replica or device if possible, else the primary."""
replica_id = _get_current_replica_id_as_int() replica_id = _get_current_replica_id_as_int()
if replica_id is None: if replica_id is None:
# Try to find a value on the current device. # Try to find a value on the current device.
@ -83,12 +83,12 @@ class DistributedValues(object):
for value in self._values: for value in self._values:
if device_util.canonicalize(value.device) == current_device: if device_util.canonicalize(value.device) == current_device:
return value return value
return self._primary return self.primary
else: else:
return self._values[replica_id] return self._values[replica_id]
@property @property
def _primary(self): def primary(self):
"""Returns a representative component.""" """Returns a representative component."""
return self._values[0] return self._values[0]
@ -368,7 +368,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
def __init__(self, strategy, values): def __init__(self, strategy, values):
self._distribute_strategy = strategy self._distribute_strategy = strategy
super(DistributedVariable, self).__init__(values) super(DistributedVariable, self).__init__(values)
self._common_name = self._primary.name.split(":")[0] self._common_name = self.primary.name.split(":")[0]
# Use a weakref to make it easy to map from the contained values # Use a weakref to make it easy to map from the contained values
# to the container without introducing a reference cycle. # to the container without introducing a reference cycle.
for v in values: for v in values:
@ -395,7 +395,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
The op that evaluates to True or False depending on if all the The op that evaluates to True or False depending on if all the
component variables are initialized. component variables are initialized.
""" """
result = self._primary.is_initialized() result = self.primary.is_initialized()
# We iterate through the list of values except the last one to allow us to # We iterate through the list of values except the last one to allow us to
# name the final `logical_and` op the same name that is passed by the user # name the final `logical_and` op the same name that is passed by the user
# to the `is_initialized` op. For distributed variables, the # to the `is_initialized` op. For distributed variables, the
@ -426,11 +426,11 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
@property @property
def constraint(self): def constraint(self):
return self._primary.constraint return self.primary.constraint
@property @property
def graph(self): def graph(self):
return self._primary.graph return self.primary.graph
@property @property
def _shared_name(self): def _shared_name(self):
@ -438,28 +438,28 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
@property @property
def _unique_id(self): def _unique_id(self):
return self._primary._unique_id # pylint: disable=protected-access return self.primary._unique_id # pylint: disable=protected-access
@property @property
def _graph_key(self): def _graph_key(self):
"""Lets Optimizers know which graph this variable is from.""" """Lets Optimizers know which graph this variable is from."""
return self._primary._graph_key # pylint: disable=protected-access return self.primary._graph_key # pylint: disable=protected-access
@property @property
def name(self): def name(self):
return self._primary.name return self.primary.name
@property @property
def dtype(self): def dtype(self):
return self._primary.dtype return self.primary.dtype
@property @property
def shape(self): def shape(self):
return self._primary.shape return self.primary.shape
@property @property
def synchronization(self): def synchronization(self):
return self._primary.synchronization return self.primary.synchronization
@property @property
def handle(self): def handle(self):
@ -475,10 +475,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
@property @property
def _save_slice_info(self): def _save_slice_info(self):
return self._primary._save_slice_info # pylint: disable=protected-access return self.primary._save_slice_info # pylint: disable=protected-access
def _get_save_slice_info(self): def _get_save_slice_info(self):
return self._primary._get_save_slice_info() # pylint: disable=protected-access return self.primary._get_save_slice_info() # pylint: disable=protected-access
def _set_save_slice_info(self, save_slice_info): def _set_save_slice_info(self, save_slice_info):
for v in self._values: for v in self._values:
@ -490,17 +490,17 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
@property @property
def trainable(self): def trainable(self):
return self._primary.trainable return self.primary.trainable
@property @property
def distribute_strategy(self): def distribute_strategy(self):
return self._distribute_strategy return self._distribute_strategy
def get_shape(self): def get_shape(self):
return self._primary.get_shape() return self.primary.get_shape()
def to_proto(self, export_scope=None): def to_proto(self, export_scope=None):
return self._primary.to_proto(export_scope=export_scope) return self.primary.to_proto(export_scope=export_scope)
@property @property
def op(self): def op(self):
@ -508,13 +508,13 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
# to work (even if the current device isn't in self.devices), but # to work (even if the current device isn't in self.devices), but
# other uses of var.op in a cross-replica context to fail. # other uses of var.op in a cross-replica context to fail.
if distribution_strategy_context.in_cross_replica_context(): if distribution_strategy_context.in_cross_replica_context():
return DistributedVarOp(self._primary.op.name, self._primary.op.graph, return DistributedVarOp(self.primary.op.name, self.primary.op.graph,
self._primary.op.traceback, self._primary.op.type) self.primary.op.traceback, self.primary.op.type)
return self._get().op return self._get().op
@property @property
def _in_graph_mode(self): def _in_graph_mode(self):
return self._primary._in_graph_mode # pylint: disable=protected-access return self.primary._in_graph_mode # pylint: disable=protected-access
def read_value(self): def read_value(self):
with _enter_or_assert_strategy(self._distribute_strategy): with _enter_or_assert_strategy(self._distribute_strategy):
@ -567,7 +567,7 @@ class TPUVariableMixin(object):
# Handle ID is needed for `get_replicated_var_handle` to cache the variables # Handle ID is needed for `get_replicated_var_handle` to cache the variables
# correctly since in eager mode different variables can have the same name. # correctly since in eager mode different variables can have the same name.
if ops.executing_eagerly_outside_functions(): if ops.executing_eagerly_outside_functions():
self._handle_id = self._common_name + "_" + str(id(self._primary)) self._handle_id = self._common_name + "_" + str(id(self.primary))
else: else:
self._handle_id = self._common_name self._handle_id = self._common_name
@ -592,7 +592,7 @@ class TPUVariableMixin(object):
if _enclosing_tpu_context() is None: if _enclosing_tpu_context() is None:
return super(TPUVariableMixin, self)._get_closest() return super(TPUVariableMixin, self)._get_closest()
else: else:
return self._primary return self.primary
def numpy(self): def numpy(self):
if context.executing_eagerly(): if context.executing_eagerly():
@ -644,8 +644,8 @@ class TPUVariableMixin(object):
@property @property
def op(self): def op(self):
return DistributedVarOp(self._primary.op.name, self._primary.op.graph, return DistributedVarOp(self.primary.op.name, self.primary.op.graph,
self._primary.op.traceback, self._primary.op.type) self.primary.op.traceback, self.primary.op.type)
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor.""" """Converts a variable to a tensor."""
@ -900,7 +900,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
""" """
def _saveable_factory(name=self._common_name): def _saveable_factory(name=self._common_name):
return _MirroredSaveable(self, self._primary, name) return _MirroredSaveable(self, self.primary, name)
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
@ -1003,8 +1003,7 @@ class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
slice_spec="", slice_spec="",
name=name, name=name,
dtype=sync_on_read_variable.dtype, dtype=sync_on_read_variable.dtype,
device=sync_on_read_variable._primary.device) # pylint: disable=protected-access device=sync_on_read_variable.primary.device)
super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name) super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name)
def restore(self, restored_tensors, restored_shapes): def restore(self, restored_tensors, restored_shapes):
@ -1104,7 +1103,7 @@ class SyncOnReadVariable(DistributedVariable):
def _get_cross_replica(self): def _get_cross_replica(self):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return self._primary return self.primary
with _enter_or_assert_strategy(self._distribute_strategy): with _enter_or_assert_strategy(self._distribute_strategy):
return self._distribute_strategy.reduce( return self._distribute_strategy.reduce(

View File

@ -274,7 +274,7 @@ class _SaveableView(object):
self.captured_tensor_node_ids[obj.resource_handle] = node_id self.captured_tensor_node_ids[obj.resource_handle] = node_id
elif (ds_values.is_distributed_variable(obj) or elif (ds_values.is_distributed_variable(obj) or
resource_variable_ops.is_resource_variable(obj)): resource_variable_ops.is_resource_variable(obj)):
obj_to_copy = obj._primary if ds_values.is_distributed_variable( # pylint: disable=protected-access obj_to_copy = obj.primary if ds_values.is_distributed_variable(
obj) else obj obj) else obj
new_variable = resource_variable_ops.copy_to_graph_uninitialized( new_variable = resource_variable_ops.copy_to_graph_uninitialized(
obj_to_copy) obj_to_copy)