PSv2: Revert _tensor naming to _value as it's not necessarily tensors.

PiperOrigin-RevId: 344154704
Change-Id: I6703542010e04c0c81200cb6a480eb62d2ec97ad
This commit is contained in:
Rick Chao 2020-11-24 16:44:10 -08:00 committed by TensorFlower Gardener
parent 5467664c23
commit 5aff086d72

View File

@ -176,7 +176,7 @@ class RemoteValueImpl(RemoteValue):
"""
self._closure = closure
self._type_spec = type_spec
self._tensors = None
self._values = None
self._fetched_numpys = None
self._error = None
self._status_available_event = threading.Event()
@ -184,7 +184,7 @@ class RemoteValueImpl(RemoteValue):
def _set_aborted(self):
self._status = _RemoteValueStatus.ABORTED
self._tensors = None
self._values = None
self._error = None
# Wake up any waiting thread and clear the event.
@ -195,21 +195,21 @@ class RemoteValueImpl(RemoteValue):
# TODO(yuefengz): we may need to rebuild its inputs as well.
self._closure.execute_on(worker)
def _set_tensors(self, tensors):
def _set_values(self, tensors):
self._status = _RemoteValueStatus.READY
self._tensors = tensors
self._values = tensors
self._error = None
self._status_available_event.set()
def _set_error(self, exception):
self._status = _RemoteValueStatus.READY
self._tensors = None
self._values = None
self._error = exception
self._status_available_event.set()
def _get_tensors(self):
def _get_values(self):
self._status_available_event.wait()
return self._tensors
return self._values
def _get_error(self):
self._status_available_event.wait()
@ -226,7 +226,7 @@ class RemoteValueImpl(RemoteValue):
raise self._error
if self._fetched_numpys is None:
self._fetched_numpys = nest.map_structure(
lambda x: x.numpy() if hasattr(x, "numpy") else x, self._tensors)
lambda x: x.numpy() if hasattr(x, "numpy") else x, self._values)
return self._fetched_numpys
@ -273,7 +273,7 @@ def _maybe_get_remote_value(val):
raise AssertionError(
"RemoteValue doesn't have a value because it has errors.")
else:
return val._get_tensors() # pylint: disable=protected-access
return val._get_values() # pylint: disable=protected-access
else:
return val
@ -408,10 +408,10 @@ class Closure(object):
with ops.device(worker.device_name):
with context.executor_scope(worker.executor):
with metric_utils.monitored_timer("closure_execution"):
output_tensors = self._function(
output_values = self._function(
*nest.map_structure(_maybe_get_remote_value, replica_args),
**nest.map_structure(_maybe_get_remote_value, replica_kwargs))
self.output_remote_value._set_tensors(output_tensors) # pylint: disable=protected-access
self.output_remote_value._set_values(output_values) # pylint: disable=protected-access
class _CoordinatedClosureQueue(object):