Special case tfnp ndarrays in some places, and minor improvements to tfnp array

functions

PiperOrigin-RevId: 317735476
Change-Id: I58123c5fb64257975e220df4e57f2a1f74856da7
This commit is contained in:
Scott Wegner 2020-06-22 14:26:34 -07:00 committed by TensorFlower Gardener
parent 7a6c2f69d4
commit 0c227aed65
4 changed files with 5 additions and 35 deletions

View File

@ -560,7 +560,6 @@ py_library(
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:util",
"//tensorflow/python/ops/numpy_ops:numpy",
"//third_party/py/numpy",
"@six_archive//:six",
],

View File

@ -81,9 +81,6 @@ from tensorflow.python.util import tf_inspect
ag_ctx = lazy_loader.LazyLoader(
"ag_ctx", globals(),
"tensorflow.python.autograph.core.ag_ctx")
np_arrays = lazy_loader.LazyLoader(
"np_arrays", globals(),
"tensorflow.python.ops.numpy_ops.np_arrays")
FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
@ -1487,11 +1484,6 @@ class ConcreteFunction(object):
self._func_graph = func_graph
self._captured_inputs = self._func_graph.external_captures
self._captured_closures = self._func_graph.deferred_external_captures
structured_outputs = self._func_graph.structured_outputs
self._ndarrays_list = (
isinstance(structured_outputs, (list, tuple)) and
all([isinstance(o, np_arrays.ndarray) for o in structured_outputs]))
self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray)
# function_spec defines the structured signature.
self._set_function_spec(function_spec)
@ -2158,15 +2150,9 @@ class ConcreteFunction(object):
if self._func_graph.structured_outputs is None:
return result
if result:
if self._ndarrays_list:
return [np_arrays.tensor_to_ndarray(o) for o in result]
elif self._ndarray_singleton:
return np_arrays.tensor_to_ndarray(result[0])
# Replace outputs with results, skipping over any 'None' values.
outputs_list = nest.flatten(
self._func_graph.structured_outputs, expand_composites=True)
outputs_list = nest.flatten(self._func_graph.structured_outputs,
expand_composites=True)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:

View File

@ -149,23 +149,10 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
if dtype and dtype != buffer.dtype:
buffer = array_ops.bitcast(buffer, dtype)
self._data = buffer
self._type_spec_internal = None
@classmethod
def from_tensor(cls, tensor):
o = cls.__new__(cls, None)
# pylint: disable=protected-access
o._data = tensor
o._type_spec_internal = None
# pylint: enable=protected-access
return o
@property
def _type_spec(self):
if self._type_spec_internal is None:
self._type_spec_internal = NdarraySpec(
type_spec.type_spec_from_value(self._data))
return self._type_spec_internal
return NdarraySpec(type_spec.type_spec_from_value(self._data))
@property
def data(self):
@ -312,7 +299,7 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
def tensor_to_ndarray(tensor):
return ndarray.from_tensor(tensor)
return ndarray(tensor._shape_tuple(), dtype=tensor.dtype, buffer=tensor) # pylint: disable=protected-access
def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):

View File

@ -64,9 +64,7 @@ class InteropTest(test.TestCase):
dx, dy = t.gradient([xx, yy], [x, y])
# # TODO(nareshmodi): Figure out a way to rewrap ndarray as tensors.
# self.assertIsInstance(dx, np_arrays.ndarray)
# self.assertIsInstance(dy, np_arrays.ndarray)
# TODO(nareshmodi): Gradient tape returns tensors. Is it possible to rewrap?
self.assertAllClose(dx, 2.0)
self.assertAllClose(dy, 3.0)