diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py index e26b829fa85..70eb7add2b7 100644 --- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py @@ -176,7 +176,7 @@ class RemoteValueImpl(RemoteValue): """ self._closure = closure self._type_spec = type_spec - self._tensors = None + self._values = None self._fetched_numpys = None self._error = None self._status_available_event = threading.Event() @@ -184,7 +184,7 @@ class RemoteValueImpl(RemoteValue): def _set_aborted(self): self._status = _RemoteValueStatus.ABORTED - self._tensors = None + self._values = None self._error = None # Wake up any waiting thread and clear the event. @@ -195,21 +195,21 @@ class RemoteValueImpl(RemoteValue): # TODO(yuefengz): we may need to rebuild its inputs as well. self._closure.execute_on(worker) - def _set_tensors(self, tensors): + def _set_values(self, tensors): self._status = _RemoteValueStatus.READY - self._tensors = tensors + self._values = tensors self._error = None self._status_available_event.set() def _set_error(self, exception): self._status = _RemoteValueStatus.READY - self._tensors = None + self._values = None self._error = exception self._status_available_event.set() - def _get_tensors(self): + def _get_values(self): self._status_available_event.wait() - return self._tensors + return self._values def _get_error(self): self._status_available_event.wait() @@ -226,7 +226,7 @@ class RemoteValueImpl(RemoteValue): raise self._error if self._fetched_numpys is None: self._fetched_numpys = nest.map_structure( - lambda x: x.numpy() if hasattr(x, "numpy") else x, self._tensors) + lambda x: x.numpy() if hasattr(x, "numpy") else x, self._values) return self._fetched_numpys @@ -273,7 +273,7 @@ def _maybe_get_remote_value(val): raise AssertionError( "RemoteValue doesn't have a value because it has errors.") else: - return val._get_tensors() # pylint: disable=protected-access + return val._get_values() # pylint: disable=protected-access else: return val @@ -408,10 +408,10 @@ class Closure(object): with ops.device(worker.device_name): with context.executor_scope(worker.executor): with metric_utils.monitored_timer("closure_execution"): - output_tensors = self._function( + output_values = self._function( *nest.map_structure(_maybe_get_remote_value, replica_args), **nest.map_structure(_maybe_get_remote_value, replica_kwargs)) - self.output_remote_value._set_tensors(output_tensors) # pylint: disable=protected-access + self.output_remote_value._set_values(output_values) # pylint: disable=protected-access class _CoordinatedClosureQueue(object):