Rename _get_closest to more accurately reflect what it does.

PiperOrigin-RevId: 312155516
Change-Id: I27d8dd110ace0150ea735f718ed94948a9a75a74
This commit is contained in:
Anjali Sridhar 2020-05-18 14:33:45 -07:00 committed by TensorFlower Gardener
parent 3d4c5d1b57
commit 6dcb7268bb
2 changed files with 12 additions and 12 deletions

View File

@ -139,7 +139,7 @@ class DistributedValues(object):
"This method should be overridden by sub-classes which support cross-"
"replica accesses.")
def _get_closest(self):
def _get_on_device_or_primary(self):
"""Returns value in same replica or device if possible, else the _primary."""
replica_id = _get_current_replica_id_as_int()
if replica_id is None:
@ -379,7 +379,7 @@ class Mirrored(DistributedDelegate):
"""Holds a map from replica to values which are kept in sync."""
def _get_cross_replica(self):
return self._get_closest()
return self._get_on_device_or_primary()
def _as_graph_element(self):
obj = self._get()
@ -480,11 +480,11 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
return init_op
def initialized_value(self):
return self._get_closest().initialized_value()
return self._get_on_device_or_primary().initialized_value()
@property
def initial_value(self):
return self._get_closest().initial_value
return self._get_on_device_or_primary().initial_value
@property
def constraint(self):
@ -537,7 +537,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
return self._values[replica_id].handle
def eval(self, session=None):
return self._get_closest().eval(session)
return self._get_on_device_or_primary().eval(session)
@property
def _save_slice_info(self):
@ -552,7 +552,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
@property
def device(self):
return self._get_closest().device
return self._get_on_device_or_primary().device
@property
def trainable(self):
@ -587,7 +587,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
return array_ops.identity(self._get())
def value(self):
return self._get_closest().value()
return self._get_on_device_or_primary().value()
def numpy(self):
if context.executing_eagerly():
@ -961,7 +961,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
return array_ops.identity(Mirrored._get_cross_replica(self))
def _as_graph_element(self):
return self._get_closest()._as_graph_element() # pylint: disable=protected-access
return self._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access
def _gather_saveables_for_checkpoint(self):
"""Overrides Trackable method.
@ -1067,7 +1067,7 @@ class SyncOnReadVariable(DistributedVariable):
"""Holds a map from replica to variables whose values are reduced on save."""
def _update_replica(self, update_fn, value, **kwargs):
return update_fn(self._get_closest(), value, **kwargs)
return update_fn(self._get_on_device_or_primary(), value, **kwargs)
# TODO(b/154017756): Make assign behaivor in cross replica context consistent
# with MirroredVariable.
@ -1146,8 +1146,8 @@ class SyncOnReadVariable(DistributedVariable):
if ds_context.in_cross_replica_context():
return self._get_cross_replica()
else:
# _get_closest() returns a Variable.
return self._get_closest().value()
# _get_on_device_or_primary() returns a Variable.
return self._get_on_device_or_primary().value()
def _get_cross_replica(self):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:

View File

@ -768,7 +768,7 @@ class Optimizer(
# pylint: enable=protected-access
mirrored_slot = named_slots.get(key, None)
if mirrored_slot is None: return None
return mirrored_slot._get_closest() # pylint: disable=protected-access
return mirrored_slot._get_on_device_or_primary() # pylint: disable=protected-access
return named_slots.get(_var_key(var), None)