diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 8f86508f333..745f2c16e73 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -856,6 +856,10 @@ class TPUVariableMixin(object):
     raise NotImplementedError(
         "numpy() is only available when eager execution is enabled.")
 
+  def _is_mirrored(self):
+    raise NotImplementedError(
+        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
+
   @property
   def handle(self):
     # If we're in a tpu.rewrite(), return the replicated handle.
@@ -865,7 +869,8 @@ class TPUVariableMixin(object):
     else:
       return tpu_context.get_replicated_var_handle(self._handle_id,
                                                    self._values,
-                                                   self._device_map)
+                                                   self._device_map,
+                                                   self._is_mirrored())
 
   @property
   def device(self):
@@ -995,7 +1000,7 @@ def create_mirrored_variable(  # pylint: disable=missing-docstring
     raise ValueError(
         "`NONE` variable synchronization mode is not supported with `Mirrored` "
         "distribution strategy. Please change the `synchronization` for "
-        "variable: " + kwargs["name"])
+        "variable: " + str(kwargs["name"]))
   elif synchronization == vs.VariableSynchronization.ON_READ:
     is_sync_on_read = True
   elif synchronization in (
@@ -1226,6 +1231,13 @@ class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
           gen_resource_variable_ops.assign_variable_op)
       return self._assign_func(f=assign_fn, *args, **kwargs)
 
+  def _is_mirrored(self):
+    if self.aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
+      # TODO(b/142440743): Remove this check once ONLY_FIRST_REPLICA aggregation
+      # works as expected.
+      return False
+    return True
+
 
 class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
   """Class for defining how to restore a SyncOnReadVariable."""
@@ -1389,6 +1401,9 @@ class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
     return _make_raw_assign_fn(
         gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)
 
+  def _is_mirrored(self):
+    return False
+
 
 def regroup(device_map, values, wrap_class=PerReplica):
   """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index d7f6085dee9..2ebdea13b46 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -254,7 +254,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     self._pivot = pivot
     self._replicated_vars = {}
 
-  def get_replicated_var_handle(self, name, vars_, device_map=None):
+  def get_replicated_var_handle(self,
+                                name,
+                                vars_,
+                                device_map=None,
+                                is_mirrored=False):
     """Returns a variable handle for replicated TPU variable 'var'.
 
     This is a method used by an experimental replicated variable implementation
@@ -264,7 +268,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
       name: The common name of the variable.
       vars_: The replicated TPU variables.
       device_map: The DeviceMap used to create the variables if it is a
-        TPUMirroredVariable.
+        TPUMirroredVariable.
+      is_mirrored: Whether the variables are mirrored, which guarantees the
+        values in each replica are always the same.
 
     Returns:
       The handle of the TPU replicated input node.
@@ -302,7 +308,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     graph._set_control_flow_context(self.outer_context)
     handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
                                           name=name + "/handle",
-                                          is_mirrored_variable=True)
+                                          is_mirrored_variable=is_mirrored)
     graph._set_control_flow_context(saved_context)
     # pylint: enable=protected-access
     self._replicated_vars[name] = handle
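
Note (not part of the patch): below is a minimal, self-contained Python sketch of the dispatch this change introduces, for review context only. Class names follow the patch, but the bodies are simplified stand-ins: plain strings replace the vs.VariableAggregation enum, and get_replicated_var_handle is a hypothetical stub that only shows where the flag ends up; the real method lives on TPUReplicateContext and builds a tpu_replicated_input op.

    # Sketch only: each TPU variable subclass reports whether its per-replica
    # values are guaranteed identical ("mirrored"); the context merely forwards
    # that flag instead of hard-coding is_mirrored_variable=True.

    class TPUVariableMixin(object):

      def _is_mirrored(self):
        raise NotImplementedError(
            "`TPUVariableMixin._is_mirrored()` must be implemented by "
            "subclasses.")


    class TPUMirroredVariable(TPUVariableMixin):
      """Stand-in; the real class also derives from MirroredVariable."""

      def __init__(self, aggregation):
        self.aggregation = aggregation  # stand-in for vs.VariableAggregation

      def _is_mirrored(self):
        # ONLY_FIRST_REPLICA aggregation updates only replica 0, so values can
        # diverge across replicas and must not be treated as mirrored
        # (see b/142440743 in the patch).
        return self.aggregation != "ONLY_FIRST_REPLICA"


    class TPUSyncOnReadVariable(TPUVariableMixin):
      """Stand-in; sync-on-read values are only aggregated when read."""

      def _is_mirrored(self):
        return False


    # Hypothetical stub standing in for TPUReplicateContext's method; the real
    # code passes is_mirrored through to tpu_ops.tpu_replicated_input(...,
    # is_mirrored_variable=is_mirrored).
    def get_replicated_var_handle(name, vars_, device_map=None,
                                  is_mirrored=False):
      return {"name": name, "is_mirrored_variable": is_mirrored}


    if __name__ == "__main__":
      v = TPUMirroredVariable("MEAN")
      print(get_replicated_var_handle("v", [v], is_mirrored=v._is_mirrored()))
      # {'name': 'v', 'is_mirrored_variable': True}
      w = TPUMirroredVariable("ONLY_FIRST_REPLICA")
      print(w._is_mirrored())                        # False
      print(TPUSyncOnReadVariable()._is_mirrored())  # False

The design keeps the decision with the variable, which knows its own aggregation semantics, so TPUReplicateContext stays agnostic and just plumbs a boolean through to the op.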