Special case tfnp ndarrays in some places, and minor improvements to tfnp array
functions PiperOrigin-RevId: 317735476 Change-Id: I58123c5fb64257975e220df4e57f2a1f74856da7
This commit is contained in:
parent
7a6c2f69d4
commit
0c227aed65
|
@ -560,7 +560,6 @@ py_library(
|
|||
"//tensorflow/python:graph_to_function_def",
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/ops/numpy_ops:numpy",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
|
|
@ -81,9 +81,6 @@ from tensorflow.python.util import tf_inspect
|
|||
ag_ctx = lazy_loader.LazyLoader(
|
||||
"ag_ctx", globals(),
|
||||
"tensorflow.python.autograph.core.ag_ctx")
|
||||
np_arrays = lazy_loader.LazyLoader(
|
||||
"np_arrays", globals(),
|
||||
"tensorflow.python.ops.numpy_ops.np_arrays")
|
||||
|
||||
|
||||
FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
|
||||
|
@ -1487,11 +1484,6 @@ class ConcreteFunction(object):
|
|||
self._func_graph = func_graph
|
||||
self._captured_inputs = self._func_graph.external_captures
|
||||
self._captured_closures = self._func_graph.deferred_external_captures
|
||||
structured_outputs = self._func_graph.structured_outputs
|
||||
self._ndarrays_list = (
|
||||
isinstance(structured_outputs, (list, tuple)) and
|
||||
all([isinstance(o, np_arrays.ndarray) for o in structured_outputs]))
|
||||
self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray)
|
||||
|
||||
# function_spec defines the structured signature.
|
||||
self._set_function_spec(function_spec)
|
||||
|
@ -2158,15 +2150,9 @@ class ConcreteFunction(object):
|
|||
if self._func_graph.structured_outputs is None:
|
||||
return result
|
||||
|
||||
if result:
|
||||
if self._ndarrays_list:
|
||||
return [np_arrays.tensor_to_ndarray(o) for o in result]
|
||||
elif self._ndarray_singleton:
|
||||
return np_arrays.tensor_to_ndarray(result[0])
|
||||
|
||||
# Replace outputs with results, skipping over any 'None' values.
|
||||
outputs_list = nest.flatten(
|
||||
self._func_graph.structured_outputs, expand_composites=True)
|
||||
outputs_list = nest.flatten(self._func_graph.structured_outputs,
|
||||
expand_composites=True)
|
||||
j = 0
|
||||
for i, o in enumerate(outputs_list):
|
||||
if o is not None:
|
||||
|
|
|
@ -149,23 +149,10 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
|
|||
if dtype and dtype != buffer.dtype:
|
||||
buffer = array_ops.bitcast(buffer, dtype)
|
||||
self._data = buffer
|
||||
self._type_spec_internal = None
|
||||
|
||||
@classmethod
|
||||
def from_tensor(cls, tensor):
|
||||
o = cls.__new__(cls, None)
|
||||
# pylint: disable=protected-access
|
||||
o._data = tensor
|
||||
o._type_spec_internal = None
|
||||
# pylint: enable=protected-access
|
||||
return o
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
if self._type_spec_internal is None:
|
||||
self._type_spec_internal = NdarraySpec(
|
||||
type_spec.type_spec_from_value(self._data))
|
||||
return self._type_spec_internal
|
||||
return NdarraySpec(type_spec.type_spec_from_value(self._data))
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
|
@ -312,7 +299,7 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
|
|||
|
||||
|
||||
def tensor_to_ndarray(tensor):
|
||||
return ndarray.from_tensor(tensor)
|
||||
return ndarray(tensor._shape_tuple(), dtype=tensor.dtype, buffer=tensor) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
|
||||
|
|
|
@ -64,9 +64,7 @@ class InteropTest(test.TestCase):
|
|||
|
||||
dx, dy = t.gradient([xx, yy], [x, y])
|
||||
|
||||
# # TODO(nareshmodi): Figure out a way to rewrap ndarray as tensors.
|
||||
# self.assertIsInstance(dx, np_arrays.ndarray)
|
||||
# self.assertIsInstance(dy, np_arrays.ndarray)
|
||||
# TODO(nareshmodi): Gradient tape returns tensors. Is it possible to rewrap?
|
||||
self.assertAllClose(dx, 2.0)
|
||||
self.assertAllClose(dy, 3.0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue