From 6dcb7268bb28221134cd1151a730e89023d59623 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Mon, 18 May 2020 14:33:45 -0700 Subject: [PATCH] Rename `_get_closest` to more accurately reflect what it does. PiperOrigin-RevId: 312155516 Change-Id: I27d8dd110ace0150ea735f718ed94948a9a75a74 --- tensorflow/python/distribute/values.py | 22 +++++++++++----------- tensorflow/python/training/optimizer.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 444915aa123..84904f93104 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -139,7 +139,7 @@ class DistributedValues(object): "This method should be overridden by sub-classes which support cross-" "replica accesses.") - def _get_closest(self): + def _get_on_device_or_primary(self): """Returns value in same replica or device if possible, else the _primary.""" replica_id = _get_current_replica_id_as_int() if replica_id is None: @@ -379,7 +379,7 @@ class Mirrored(DistributedDelegate): """Holds a map from replica to values which are kept in sync.""" def _get_cross_replica(self): - return self._get_closest() + return self._get_on_device_or_primary() def _as_graph_element(self): obj = self._get() @@ -480,11 +480,11 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, return init_op def initialized_value(self): - return self._get_closest().initialized_value() + return self._get_on_device_or_primary().initialized_value() @property def initial_value(self): - return self._get_closest().initial_value + return self._get_on_device_or_primary().initial_value @property def constraint(self): @@ -537,7 +537,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, return self._values[replica_id].handle def eval(self, session=None): - return self._get_closest().eval(session) + return self._get_on_device_or_primary().eval(session) @property def _save_slice_info(self): @@ -552,7 +552,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, @property def device(self): - return self._get_closest().device + return self._get_on_device_or_primary().device @property def trainable(self): @@ -587,7 +587,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, return array_ops.identity(self._get()) def value(self): - return self._get_closest().value() + return self._get_on_device_or_primary().value() def numpy(self): if context.executing_eagerly(): @@ -961,7 +961,7 @@ class MirroredVariable(DistributedVariable, Mirrored): return array_ops.identity(Mirrored._get_cross_replica(self)) def _as_graph_element(self): - return self._get_closest()._as_graph_element() # pylint: disable=protected-access + return self._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access def _gather_saveables_for_checkpoint(self): """Overrides Trackable method. @@ -1067,7 +1067,7 @@ class SyncOnReadVariable(DistributedVariable): """Holds a map from replica to variables whose values are reduced on save.""" def _update_replica(self, update_fn, value, **kwargs): - return update_fn(self._get_closest(), value, **kwargs) + return update_fn(self._get_on_device_or_primary(), value, **kwargs) # TODO(b/154017756): Make assign behaivor in cross replica context consistent # with MirroredVariable. @@ -1146,8 +1146,8 @@ class SyncOnReadVariable(DistributedVariable): if ds_context.in_cross_replica_context(): return self._get_cross_replica() else: - # _get_closest() returns a Variable. - return self._get_closest().value() + # _get_on_device_or_primary() returns a Variable. + return self._get_on_device_or_primary().value() def _get_cross_replica(self): if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 9732ea04f26..1fe8a8c729b 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -768,7 +768,7 @@ class Optimizer( # pylint: enable=protected-access mirrored_slot = named_slots.get(key, None) if mirrored_slot is None: return None - return mirrored_slot._get_closest() # pylint: disable=protected-access + return mirrored_slot._get_on_device_or_primary() # pylint: disable=protected-access return named_slots.get(_var_key(var), None)