Rename _get_closest
to more accurately reflect what it does.
PiperOrigin-RevId: 312155516 Change-Id: I27d8dd110ace0150ea735f718ed94948a9a75a74
This commit is contained in:
parent
3d4c5d1b57
commit
6dcb7268bb
@ -139,7 +139,7 @@ class DistributedValues(object):
|
||||
"This method should be overridden by sub-classes which support cross-"
|
||||
"replica accesses.")
|
||||
|
||||
def _get_closest(self):
|
||||
def _get_on_device_or_primary(self):
|
||||
"""Returns value in same replica or device if possible, else the _primary."""
|
||||
replica_id = _get_current_replica_id_as_int()
|
||||
if replica_id is None:
|
||||
@ -379,7 +379,7 @@ class Mirrored(DistributedDelegate):
|
||||
"""Holds a map from replica to values which are kept in sync."""
|
||||
|
||||
def _get_cross_replica(self):
|
||||
return self._get_closest()
|
||||
return self._get_on_device_or_primary()
|
||||
|
||||
def _as_graph_element(self):
|
||||
obj = self._get()
|
||||
@ -480,11 +480,11 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
return init_op
|
||||
|
||||
def initialized_value(self):
|
||||
return self._get_closest().initialized_value()
|
||||
return self._get_on_device_or_primary().initialized_value()
|
||||
|
||||
@property
|
||||
def initial_value(self):
|
||||
return self._get_closest().initial_value
|
||||
return self._get_on_device_or_primary().initial_value
|
||||
|
||||
@property
|
||||
def constraint(self):
|
||||
@ -537,7 +537,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
return self._values[replica_id].handle
|
||||
|
||||
def eval(self, session=None):
|
||||
return self._get_closest().eval(session)
|
||||
return self._get_on_device_or_primary().eval(session)
|
||||
|
||||
@property
|
||||
def _save_slice_info(self):
|
||||
@ -552,7 +552,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._get_closest().device
|
||||
return self._get_on_device_or_primary().device
|
||||
|
||||
@property
|
||||
def trainable(self):
|
||||
@ -587,7 +587,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
return array_ops.identity(self._get())
|
||||
|
||||
def value(self):
|
||||
return self._get_closest().value()
|
||||
return self._get_on_device_or_primary().value()
|
||||
|
||||
def numpy(self):
|
||||
if context.executing_eagerly():
|
||||
@ -961,7 +961,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
return array_ops.identity(Mirrored._get_cross_replica(self))
|
||||
|
||||
def _as_graph_element(self):
|
||||
return self._get_closest()._as_graph_element() # pylint: disable=protected-access
|
||||
return self._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access
|
||||
|
||||
def _gather_saveables_for_checkpoint(self):
|
||||
"""Overrides Trackable method.
|
||||
@ -1067,7 +1067,7 @@ class SyncOnReadVariable(DistributedVariable):
|
||||
"""Holds a map from replica to variables whose values are reduced on save."""
|
||||
|
||||
def _update_replica(self, update_fn, value, **kwargs):
|
||||
return update_fn(self._get_closest(), value, **kwargs)
|
||||
return update_fn(self._get_on_device_or_primary(), value, **kwargs)
|
||||
|
||||
# TODO(b/154017756): Make assign behaivor in cross replica context consistent
|
||||
# with MirroredVariable.
|
||||
@ -1146,8 +1146,8 @@ class SyncOnReadVariable(DistributedVariable):
|
||||
if ds_context.in_cross_replica_context():
|
||||
return self._get_cross_replica()
|
||||
else:
|
||||
# _get_closest() returns a Variable.
|
||||
return self._get_closest().value()
|
||||
# _get_on_device_or_primary() returns a Variable.
|
||||
return self._get_on_device_or_primary().value()
|
||||
|
||||
def _get_cross_replica(self):
|
||||
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||
|
@ -768,7 +768,7 @@ class Optimizer(
|
||||
# pylint: enable=protected-access
|
||||
mirrored_slot = named_slots.get(key, None)
|
||||
if mirrored_slot is None: return None
|
||||
return mirrored_slot._get_closest() # pylint: disable=protected-access
|
||||
return mirrored_slot._get_on_device_or_primary() # pylint: disable=protected-access
|
||||
|
||||
return named_slots.get(_var_key(var), None)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user