Only propagate "is_mirrored_variable" tag for mirrored variables which values are guaranteed to be the same all the time.
PiperOrigin-RevId: 274015168
This commit is contained in:
parent
444b2aced6
commit
090c30918e
@ -856,6 +856,10 @@ class TPUVariableMixin(object):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"numpy() is only available when eager execution is enabled.")
|
"numpy() is only available when eager execution is enabled.")
|
||||||
|
|
||||||
|
def _is_mirrored(self):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def handle(self):
|
def handle(self):
|
||||||
# If we're in a tpu.rewrite(), return the replicated handle.
|
# If we're in a tpu.rewrite(), return the replicated handle.
|
||||||
@ -865,7 +869,8 @@ class TPUVariableMixin(object):
|
|||||||
else:
|
else:
|
||||||
return tpu_context.get_replicated_var_handle(self._handle_id,
|
return tpu_context.get_replicated_var_handle(self._handle_id,
|
||||||
self._values,
|
self._values,
|
||||||
self._device_map)
|
self._device_map,
|
||||||
|
self._is_mirrored())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
@ -995,7 +1000,7 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`NONE` variable synchronization mode is not supported with `Mirrored` "
|
"`NONE` variable synchronization mode is not supported with `Mirrored` "
|
||||||
"distribution strategy. Please change the `synchronization` for "
|
"distribution strategy. Please change the `synchronization` for "
|
||||||
"variable: " + kwargs["name"])
|
"variable: " + str(kwargs["name"]))
|
||||||
elif synchronization == vs.VariableSynchronization.ON_READ:
|
elif synchronization == vs.VariableSynchronization.ON_READ:
|
||||||
is_sync_on_read = True
|
is_sync_on_read = True
|
||||||
elif synchronization in (
|
elif synchronization in (
|
||||||
@ -1226,6 +1231,13 @@ class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
|
|||||||
gen_resource_variable_ops.assign_variable_op)
|
gen_resource_variable_ops.assign_variable_op)
|
||||||
return self._assign_func(f=assign_fn, *args, **kwargs)
|
return self._assign_func(f=assign_fn, *args, **kwargs)
|
||||||
|
|
||||||
|
def _is_mirrored(self):
|
||||||
|
if self.aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||||
|
# TODO(b/142440743): Remove this check once ONLY_FIRST_REPLICA aggregation
|
||||||
|
# works as expected.
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
|
class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
|
||||||
"""Class for defining how to restore a SyncOnReadVariable."""
|
"""Class for defining how to restore a SyncOnReadVariable."""
|
||||||
@ -1389,6 +1401,9 @@ class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
|
|||||||
return _make_raw_assign_fn(
|
return _make_raw_assign_fn(
|
||||||
gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)
|
gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)
|
||||||
|
|
||||||
|
def _is_mirrored(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def regroup(device_map, values, wrap_class=PerReplica):
|
def regroup(device_map, values, wrap_class=PerReplica):
|
||||||
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
|
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
|
||||||
|
@ -254,7 +254,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
self._pivot = pivot
|
self._pivot = pivot
|
||||||
self._replicated_vars = {}
|
self._replicated_vars = {}
|
||||||
|
|
||||||
def get_replicated_var_handle(self, name, vars_, device_map=None):
|
def get_replicated_var_handle(self,
|
||||||
|
name,
|
||||||
|
vars_,
|
||||||
|
device_map=None,
|
||||||
|
is_mirrored=False):
|
||||||
"""Returns a variable handle for replicated TPU variable 'var'.
|
"""Returns a variable handle for replicated TPU variable 'var'.
|
||||||
|
|
||||||
This is a method used by an experimental replicated variable implementation
|
This is a method used by an experimental replicated variable implementation
|
||||||
@ -264,7 +268,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
name: The common name of the variable.
|
name: The common name of the variable.
|
||||||
vars_: The replicated TPU variables.
|
vars_: The replicated TPU variables.
|
||||||
device_map: The DeviceMap used to create the variables if it is a
|
device_map: The DeviceMap used to create the variables if it is a
|
||||||
TPUMirroredVariable.
|
TPUMirroredVariable.
|
||||||
|
is_mirrored: Whether the variables are mirrored, which guarantees the
|
||||||
|
values in each replica are always the same.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The handle of the TPU replicated input node.
|
The handle of the TPU replicated input node.
|
||||||
@ -302,7 +308,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
graph._set_control_flow_context(self.outer_context)
|
graph._set_control_flow_context(self.outer_context)
|
||||||
handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
|
handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
|
||||||
name=name + "/handle",
|
name=name + "/handle",
|
||||||
is_mirrored_variable=True)
|
is_mirrored_variable=is_mirrored)
|
||||||
graph._set_control_flow_context(saved_context)
|
graph._set_control_flow_context(saved_context)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self._replicated_vars[name] = handle
|
self._replicated_vars[name] = handle
|
||||||
|
Loading…
x
Reference in New Issue
Block a user