Rename _get_closest
to more accurately reflect what it does.
PiperOrigin-RevId: 312155516 Change-Id: I27d8dd110ace0150ea735f718ed94948a9a75a74
This commit is contained in:
parent
3d4c5d1b57
commit
6dcb7268bb
@ -139,7 +139,7 @@ class DistributedValues(object):
|
|||||||
"This method should be overridden by sub-classes which support cross-"
|
"This method should be overridden by sub-classes which support cross-"
|
||||||
"replica accesses.")
|
"replica accesses.")
|
||||||
|
|
||||||
def _get_closest(self):
|
def _get_on_device_or_primary(self):
|
||||||
"""Returns value in same replica or device if possible, else the _primary."""
|
"""Returns value in same replica or device if possible, else the _primary."""
|
||||||
replica_id = _get_current_replica_id_as_int()
|
replica_id = _get_current_replica_id_as_int()
|
||||||
if replica_id is None:
|
if replica_id is None:
|
||||||
@ -379,7 +379,7 @@ class Mirrored(DistributedDelegate):
|
|||||||
"""Holds a map from replica to values which are kept in sync."""
|
"""Holds a map from replica to values which are kept in sync."""
|
||||||
|
|
||||||
def _get_cross_replica(self):
|
def _get_cross_replica(self):
|
||||||
return self._get_closest()
|
return self._get_on_device_or_primary()
|
||||||
|
|
||||||
def _as_graph_element(self):
|
def _as_graph_element(self):
|
||||||
obj = self._get()
|
obj = self._get()
|
||||||
@ -480,11 +480,11 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
|||||||
return init_op
|
return init_op
|
||||||
|
|
||||||
def initialized_value(self):
|
def initialized_value(self):
|
||||||
return self._get_closest().initialized_value()
|
return self._get_on_device_or_primary().initialized_value()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def initial_value(self):
|
def initial_value(self):
|
||||||
return self._get_closest().initial_value
|
return self._get_on_device_or_primary().initial_value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def constraint(self):
|
def constraint(self):
|
||||||
@ -537,7 +537,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
|||||||
return self._values[replica_id].handle
|
return self._values[replica_id].handle
|
||||||
|
|
||||||
def eval(self, session=None):
|
def eval(self, session=None):
|
||||||
return self._get_closest().eval(session)
|
return self._get_on_device_or_primary().eval(session)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _save_slice_info(self):
|
def _save_slice_info(self):
|
||||||
@ -552,7 +552,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
return self._get_closest().device
|
return self._get_on_device_or_primary().device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def trainable(self):
|
def trainable(self):
|
||||||
@ -587,7 +587,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
|||||||
return array_ops.identity(self._get())
|
return array_ops.identity(self._get())
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
return self._get_closest().value()
|
return self._get_on_device_or_primary().value()
|
||||||
|
|
||||||
def numpy(self):
|
def numpy(self):
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
@ -961,7 +961,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
return array_ops.identity(Mirrored._get_cross_replica(self))
|
return array_ops.identity(Mirrored._get_cross_replica(self))
|
||||||
|
|
||||||
def _as_graph_element(self):
|
def _as_graph_element(self):
|
||||||
return self._get_closest()._as_graph_element() # pylint: disable=protected-access
|
return self._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access
|
||||||
|
|
||||||
def _gather_saveables_for_checkpoint(self):
|
def _gather_saveables_for_checkpoint(self):
|
||||||
"""Overrides Trackable method.
|
"""Overrides Trackable method.
|
||||||
@ -1067,7 +1067,7 @@ class SyncOnReadVariable(DistributedVariable):
|
|||||||
"""Holds a map from replica to variables whose values are reduced on save."""
|
"""Holds a map from replica to variables whose values are reduced on save."""
|
||||||
|
|
||||||
def _update_replica(self, update_fn, value, **kwargs):
|
def _update_replica(self, update_fn, value, **kwargs):
|
||||||
return update_fn(self._get_closest(), value, **kwargs)
|
return update_fn(self._get_on_device_or_primary(), value, **kwargs)
|
||||||
|
|
||||||
# TODO(b/154017756): Make assign behaivor in cross replica context consistent
|
# TODO(b/154017756): Make assign behaivor in cross replica context consistent
|
||||||
# with MirroredVariable.
|
# with MirroredVariable.
|
||||||
@ -1146,8 +1146,8 @@ class SyncOnReadVariable(DistributedVariable):
|
|||||||
if ds_context.in_cross_replica_context():
|
if ds_context.in_cross_replica_context():
|
||||||
return self._get_cross_replica()
|
return self._get_cross_replica()
|
||||||
else:
|
else:
|
||||||
# _get_closest() returns a Variable.
|
# _get_on_device_or_primary() returns a Variable.
|
||||||
return self._get_closest().value()
|
return self._get_on_device_or_primary().value()
|
||||||
|
|
||||||
def _get_cross_replica(self):
|
def _get_cross_replica(self):
|
||||||
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||||
|
@ -768,7 +768,7 @@ class Optimizer(
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
mirrored_slot = named_slots.get(key, None)
|
mirrored_slot = named_slots.get(key, None)
|
||||||
if mirrored_slot is None: return None
|
if mirrored_slot is None: return None
|
||||||
return mirrored_slot._get_closest() # pylint: disable=protected-access
|
return mirrored_slot._get_on_device_or_primary() # pylint: disable=protected-access
|
||||||
|
|
||||||
return named_slots.get(_var_key(var), None)
|
return named_slots.get(_var_key(var), None)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user