Only propagate "is_mirrored_variable" tag for mirrored variables which values are guaranteed to be the same all the time.
PiperOrigin-RevId: 274015168
This commit is contained in:
parent
444b2aced6
commit
090c30918e
@ -856,6 +856,10 @@ class TPUVariableMixin(object):
|
||||
raise NotImplementedError(
|
||||
"numpy() is only available when eager execution is enabled.")
|
||||
|
||||
def _is_mirrored(self):
|
||||
raise NotImplementedError(
|
||||
"`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
# If we're in a tpu.rewrite(), return the replicated handle.
|
||||
@ -865,7 +869,8 @@ class TPUVariableMixin(object):
|
||||
else:
|
||||
return tpu_context.get_replicated_var_handle(self._handle_id,
|
||||
self._values,
|
||||
self._device_map)
|
||||
self._device_map,
|
||||
self._is_mirrored())
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
@ -995,7 +1000,7 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
|
||||
raise ValueError(
|
||||
"`NONE` variable synchronization mode is not supported with `Mirrored` "
|
||||
"distribution strategy. Please change the `synchronization` for "
|
||||
"variable: " + kwargs["name"])
|
||||
"variable: " + str(kwargs["name"]))
|
||||
elif synchronization == vs.VariableSynchronization.ON_READ:
|
||||
is_sync_on_read = True
|
||||
elif synchronization in (
|
||||
@ -1226,6 +1231,13 @@ class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
|
||||
gen_resource_variable_ops.assign_variable_op)
|
||||
return self._assign_func(f=assign_fn, *args, **kwargs)
|
||||
|
||||
def _is_mirrored(self):
|
||||
if self.aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||
# TODO(b/142440743): Remove this check once ONLY_FIRST_REPLICA aggregation
|
||||
# works as expected.
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
|
||||
"""Class for defining how to restore a SyncOnReadVariable."""
|
||||
@ -1389,6 +1401,9 @@ class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
|
||||
return _make_raw_assign_fn(
|
||||
gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)
|
||||
|
||||
def _is_mirrored(self):
|
||||
return False
|
||||
|
||||
|
||||
def regroup(device_map, values, wrap_class=PerReplica):
|
||||
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
|
||||
|
@ -254,7 +254,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
self._pivot = pivot
|
||||
self._replicated_vars = {}
|
||||
|
||||
def get_replicated_var_handle(self, name, vars_, device_map=None):
|
||||
def get_replicated_var_handle(self,
|
||||
name,
|
||||
vars_,
|
||||
device_map=None,
|
||||
is_mirrored=False):
|
||||
"""Returns a variable handle for replicated TPU variable 'var'.
|
||||
|
||||
This is a method used by an experimental replicated variable implementation
|
||||
@ -264,7 +268,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
name: The common name of the variable.
|
||||
vars_: The replicated TPU variables.
|
||||
device_map: The DeviceMap used to create the variables if it is a
|
||||
TPUMirroredVariable.
|
||||
TPUMirroredVariable.
|
||||
is_mirrored: Whether the variables are mirrored, which guarantees the
|
||||
values in each replica are always the same.
|
||||
|
||||
Returns:
|
||||
The handle of the TPU replicated input node.
|
||||
@ -302,7 +308,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
graph._set_control_flow_context(self.outer_context)
|
||||
handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
|
||||
name=name + "/handle",
|
||||
is_mirrored_variable=True)
|
||||
is_mirrored_variable=is_mirrored)
|
||||
graph._set_control_flow_context(saved_context)
|
||||
# pylint: enable=protected-access
|
||||
self._replicated_vars[name] = handle
|
||||
|
Loading…
x
Reference in New Issue
Block a user