PSv2: Revert _tensor naming to _value as it's not necessarily tensors.
PiperOrigin-RevId: 344154704 Change-Id: I6703542010e04c0c81200cb6a480eb62d2ec97ad
This commit is contained in:
parent
5467664c23
commit
5aff086d72
@ -176,7 +176,7 @@ class RemoteValueImpl(RemoteValue):
|
||||
"""
|
||||
self._closure = closure
|
||||
self._type_spec = type_spec
|
||||
self._tensors = None
|
||||
self._values = None
|
||||
self._fetched_numpys = None
|
||||
self._error = None
|
||||
self._status_available_event = threading.Event()
|
||||
@ -184,7 +184,7 @@ class RemoteValueImpl(RemoteValue):
|
||||
|
||||
def _set_aborted(self):
|
||||
self._status = _RemoteValueStatus.ABORTED
|
||||
self._tensors = None
|
||||
self._values = None
|
||||
self._error = None
|
||||
|
||||
# Wake up any waiting thread and clear the event.
|
||||
@ -195,21 +195,21 @@ class RemoteValueImpl(RemoteValue):
|
||||
# TODO(yuefengz): we may need to rebuild its inputs as well.
|
||||
self._closure.execute_on(worker)
|
||||
|
||||
def _set_tensors(self, tensors):
|
||||
def _set_values(self, tensors):
|
||||
self._status = _RemoteValueStatus.READY
|
||||
self._tensors = tensors
|
||||
self._values = tensors
|
||||
self._error = None
|
||||
self._status_available_event.set()
|
||||
|
||||
def _set_error(self, exception):
|
||||
self._status = _RemoteValueStatus.READY
|
||||
self._tensors = None
|
||||
self._values = None
|
||||
self._error = exception
|
||||
self._status_available_event.set()
|
||||
|
||||
def _get_tensors(self):
|
||||
def _get_values(self):
|
||||
self._status_available_event.wait()
|
||||
return self._tensors
|
||||
return self._values
|
||||
|
||||
def _get_error(self):
|
||||
self._status_available_event.wait()
|
||||
@ -226,7 +226,7 @@ class RemoteValueImpl(RemoteValue):
|
||||
raise self._error
|
||||
if self._fetched_numpys is None:
|
||||
self._fetched_numpys = nest.map_structure(
|
||||
lambda x: x.numpy() if hasattr(x, "numpy") else x, self._tensors)
|
||||
lambda x: x.numpy() if hasattr(x, "numpy") else x, self._values)
|
||||
return self._fetched_numpys
|
||||
|
||||
|
||||
@ -273,7 +273,7 @@ def _maybe_get_remote_value(val):
|
||||
raise AssertionError(
|
||||
"RemoteValue doesn't have a value because it has errors.")
|
||||
else:
|
||||
return val._get_tensors() # pylint: disable=protected-access
|
||||
return val._get_values() # pylint: disable=protected-access
|
||||
else:
|
||||
return val
|
||||
|
||||
@ -408,10 +408,10 @@ class Closure(object):
|
||||
with ops.device(worker.device_name):
|
||||
with context.executor_scope(worker.executor):
|
||||
with metric_utils.monitored_timer("closure_execution"):
|
||||
output_tensors = self._function(
|
||||
output_values = self._function(
|
||||
*nest.map_structure(_maybe_get_remote_value, replica_args),
|
||||
**nest.map_structure(_maybe_get_remote_value, replica_kwargs))
|
||||
self.output_remote_value._set_tensors(output_tensors) # pylint: disable=protected-access
|
||||
self.output_remote_value._set_values(output_values) # pylint: disable=protected-access
|
||||
|
||||
|
||||
class _CoordinatedClosureQueue(object):
|
||||
|
Loading…
Reference in New Issue
Block a user