Only propagate the "is_mirrored_variable" tag for mirrored variables whose values are guaranteed to be the same at all times.
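For context: an ON_WRITE mirrored variable applies every update to all replica copies, so the copies stay identical at all times; a sync-on-read (ON_READ) variable lets each replica update its own copy and only aggregates on read, so the copies can diverge. A minimal sketch of the two kinds, assuming a working tf.distribute setup (MirroredStrategy stands in for the TPU path; the variable names are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # stand-in for a TPU strategy
with strategy.scope():
  # ON_WRITE: every assign is applied to all replica copies, so they are
  # guaranteed identical and the handle can carry "is_mirrored_variable".
  weights = tf.Variable(
      1.0,
      synchronization=tf.VariableSynchronization.ON_WRITE,
      aggregation=tf.VariableAggregation.MEAN)

  # ON_READ: each replica updates its own copy, which is aggregated only
  # when read, so the copies may diverge and the tag must not be set.
  running_sum = tf.Variable(
      0.0,
      trainable=False,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)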

PiperOrigin-RevId: 274015168
Ruoxin Sang 2019-10-10 12:32:37 -07:00 committed by TensorFlower Gardener
parent 444b2aced6
commit 090c30918e
2 changed files with 26 additions and 5 deletions


@@ -856,6 +856,10 @@ class TPUVariableMixin(object):
     raise NotImplementedError(
         "numpy() is only available when eager execution is enabled.")
 
+  def _is_mirrored(self):
+    raise NotImplementedError(
+        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
+
   @property
   def handle(self):
     # If we're in a tpu.rewrite(), return the replicated handle.
@@ -865,7 +869,8 @@ class TPUVariableMixin(object):
     else:
       return tpu_context.get_replicated_var_handle(self._handle_id,
                                                    self._values,
-                                                   self._device_map)
+                                                   self._device_map,
+                                                   self._is_mirrored())
 
   @property
   def device(self):
@@ -995,7 +1000,7 @@ def create_mirrored_variable(  # pylint: disable=missing-docstring
     raise ValueError(
         "`NONE` variable synchronization mode is not supported with `Mirrored` "
         "distribution strategy. Please change the `synchronization` for "
-        "variable: " + kwargs["name"])
+        "variable: " + str(kwargs["name"]))
   elif synchronization == vs.VariableSynchronization.ON_READ:
     is_sync_on_read = True
   elif synchronization in (
@@ -1226,6 +1231,13 @@ class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
         gen_resource_variable_ops.assign_variable_op)
     return self._assign_func(f=assign_fn, *args, **kwargs)
 
+  def _is_mirrored(self):
+    if self.aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
+      # TODO(b/142440743): Remove this check once ONLY_FIRST_REPLICA aggregation
+      # works as expected.
+      return False
+    return True
+
 
 class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
   """Class for defining how to restore a SyncOnReadVariable."""
@@ -1389,6 +1401,9 @@ class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
     return _make_raw_assign_fn(
         gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)
 
+  def _is_mirrored(self):
+    return False
+
 
 def regroup(device_map, values, wrap_class=PerReplica):
   """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""

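The ONLY_FIRST_REPLICA carve-out in _is_mirrored() above exists because that aggregation mode keeps only replica 0's update on cross-replica writes and, per the TODO (b/142440743), does not yet behave as expected on TPU, so the identical-values guarantee is withheld. A short sketch of a variable created in that mode, assuming a working tf.distribute setup (illustrative; a typical use is a step counter):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # stand-in for a TPU strategy
with strategy.scope():
  # ONLY_FIRST_REPLICA: only replica 0's update is applied on cross-replica
  # writes. After this change such a variable reports _is_mirrored() as
  # False, so its replicated handle is not tagged "is_mirrored_variable".
  global_step = tf.Variable(
      0,
      dtype=tf.int64,
      trainable=False,
      synchronization=tf.VariableSynchronization.ON_WRITE,
      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)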

@@ -254,7 +254,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     self._pivot = pivot
     self._replicated_vars = {}
 
-  def get_replicated_var_handle(self, name, vars_, device_map=None):
+  def get_replicated_var_handle(self,
+                                name,
+                                vars_,
+                                device_map=None,
+                                is_mirrored=False):
     """Returns a variable handle for replicated TPU variable 'var'.
 
     This is a method used by an experimental replicated variable implementation
@@ -264,7 +268,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
       name: The common name of the variable.
       vars_: The replicated TPU variables.
       device_map: The DeviceMap used to create the variables if it is a
-        TPUMirroredVariable.
+        TPUMirroredVariable.
+      is_mirrored: Whether the variables are mirrored, which guarantees the
+        values in each replica are always the same.
 
     Returns:
       The handle of the TPU replicated input node.
@@ -302,7 +308,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
       graph._set_control_flow_context(self.outer_context)
       handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
                                             name=name + "/handle",
-                                            is_mirrored_variable=True)
+                                            is_mirrored_variable=is_mirrored)
       graph._set_control_flow_context(saved_context)
       # pylint: enable=protected-access
       self._replicated_vars[name] = handle
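Taken together, the two files thread the flag end to end: the variable's handle property asks the variable itself whether it is mirrored, and TPUReplicateContext forwards the answer to the tpu_replicated_input op in place of the old hard-coded True. A runnable, self-contained sketch of that plumbing (every name below is a simplified stand-in for the real TensorFlow type, not the actual implementation):

def tpu_replicated_input(handles, name, is_mirrored_variable=False):
  # Stand-in for tpu_ops.tpu_replicated_input; just records the tag.
  return {"name": name, "handles": handles,
          "is_mirrored_variable": is_mirrored_variable}

class ReplicateContext:
  # Stand-in for TPUReplicateContext.
  def get_replicated_var_handle(self, name, vars_, device_map=None,
                                is_mirrored=False):
    # The caller's per-variable answer replaces the old hard-coded True.
    return tpu_replicated_input([v.handle for v in vars_],
                                name=name + "/handle",
                                is_mirrored_variable=is_mirrored)

class MirroredVar:
  # Stand-in for TPUMirroredVariable.
  def __init__(self, handle, only_first_replica=False):
    self.handle = handle
    self._only_first_replica = only_first_replica

  def _is_mirrored(self):
    # ONLY_FIRST_REPLICA aggregation forfeits the guarantee (b/142440743).
    return not self._only_first_replica

ctx = ReplicateContext()
v = MirroredVar("handle_0")
print(ctx.get_replicated_var_handle("v", [v], is_mirrored=v._is_mirrored()))
# A sync-on-read or ONLY_FIRST_REPLICA variable would pass False instead.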