Special case tfnp ndarrays in some places, and minor improvements to tfnp array
functions PiperOrigin-RevId: 317735476 Change-Id: I58123c5fb64257975e220df4e57f2a1f74856da7
This commit is contained in:
parent
7a6c2f69d4
commit
0c227aed65
|
@ -560,7 +560,6 @@ py_library(
|
||||||
"//tensorflow/python:graph_to_function_def",
|
"//tensorflow/python:graph_to_function_def",
|
||||||
"//tensorflow/python:pywrap_tf_session",
|
"//tensorflow/python:pywrap_tf_session",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python/ops/numpy_ops:numpy",
|
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
|
|
|
@ -81,9 +81,6 @@ from tensorflow.python.util import tf_inspect
|
||||||
ag_ctx = lazy_loader.LazyLoader(
|
ag_ctx = lazy_loader.LazyLoader(
|
||||||
"ag_ctx", globals(),
|
"ag_ctx", globals(),
|
||||||
"tensorflow.python.autograph.core.ag_ctx")
|
"tensorflow.python.autograph.core.ag_ctx")
|
||||||
np_arrays = lazy_loader.LazyLoader(
|
|
||||||
"np_arrays", globals(),
|
|
||||||
"tensorflow.python.ops.numpy_ops.np_arrays")
|
|
||||||
|
|
||||||
|
|
||||||
FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
|
FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
|
||||||
|
@ -1487,11 +1484,6 @@ class ConcreteFunction(object):
|
||||||
self._func_graph = func_graph
|
self._func_graph = func_graph
|
||||||
self._captured_inputs = self._func_graph.external_captures
|
self._captured_inputs = self._func_graph.external_captures
|
||||||
self._captured_closures = self._func_graph.deferred_external_captures
|
self._captured_closures = self._func_graph.deferred_external_captures
|
||||||
structured_outputs = self._func_graph.structured_outputs
|
|
||||||
self._ndarrays_list = (
|
|
||||||
isinstance(structured_outputs, (list, tuple)) and
|
|
||||||
all([isinstance(o, np_arrays.ndarray) for o in structured_outputs]))
|
|
||||||
self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray)
|
|
||||||
|
|
||||||
# function_spec defines the structured signature.
|
# function_spec defines the structured signature.
|
||||||
self._set_function_spec(function_spec)
|
self._set_function_spec(function_spec)
|
||||||
|
@ -2158,15 +2150,9 @@ class ConcreteFunction(object):
|
||||||
if self._func_graph.structured_outputs is None:
|
if self._func_graph.structured_outputs is None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if result:
|
|
||||||
if self._ndarrays_list:
|
|
||||||
return [np_arrays.tensor_to_ndarray(o) for o in result]
|
|
||||||
elif self._ndarray_singleton:
|
|
||||||
return np_arrays.tensor_to_ndarray(result[0])
|
|
||||||
|
|
||||||
# Replace outputs with results, skipping over any 'None' values.
|
# Replace outputs with results, skipping over any 'None' values.
|
||||||
outputs_list = nest.flatten(
|
outputs_list = nest.flatten(self._func_graph.structured_outputs,
|
||||||
self._func_graph.structured_outputs, expand_composites=True)
|
expand_composites=True)
|
||||||
j = 0
|
j = 0
|
||||||
for i, o in enumerate(outputs_list):
|
for i, o in enumerate(outputs_list):
|
||||||
if o is not None:
|
if o is not None:
|
||||||
|
|
|
@ -149,23 +149,10 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
|
||||||
if dtype and dtype != buffer.dtype:
|
if dtype and dtype != buffer.dtype:
|
||||||
buffer = array_ops.bitcast(buffer, dtype)
|
buffer = array_ops.bitcast(buffer, dtype)
|
||||||
self._data = buffer
|
self._data = buffer
|
||||||
self._type_spec_internal = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_tensor(cls, tensor):
|
|
||||||
o = cls.__new__(cls, None)
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
o._data = tensor
|
|
||||||
o._type_spec_internal = None
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
return o
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _type_spec(self):
|
def _type_spec(self):
|
||||||
if self._type_spec_internal is None:
|
return NdarraySpec(type_spec.type_spec_from_value(self._data))
|
||||||
self._type_spec_internal = NdarraySpec(
|
|
||||||
type_spec.type_spec_from_value(self._data))
|
|
||||||
return self._type_spec_internal
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self):
|
def data(self):
|
||||||
|
@ -312,7 +299,7 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
def tensor_to_ndarray(tensor):
|
def tensor_to_ndarray(tensor):
|
||||||
return ndarray.from_tensor(tensor)
|
return ndarray(tensor._shape_tuple(), dtype=tensor.dtype, buffer=tensor) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
|
def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
|
||||||
|
|
|
@ -64,9 +64,7 @@ class InteropTest(test.TestCase):
|
||||||
|
|
||||||
dx, dy = t.gradient([xx, yy], [x, y])
|
dx, dy = t.gradient([xx, yy], [x, y])
|
||||||
|
|
||||||
# # TODO(nareshmodi): Figure out a way to rewrap ndarray as tensors.
|
# TODO(nareshmodi): Gradient tape returns tensors. Is it possible to rewrap?
|
||||||
# self.assertIsInstance(dx, np_arrays.ndarray)
|
|
||||||
# self.assertIsInstance(dy, np_arrays.ndarray)
|
|
||||||
self.assertAllClose(dx, 2.0)
|
self.assertAllClose(dx, 2.0)
|
||||||
self.assertAllClose(dy, 3.0)
|
self.assertAllClose(dy, 3.0)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue